Deprecate prof_dag (#11956)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11956
Deprecate prof_dag and redirect it to the unified executor
Reviewed By: aazzolini
Differential Revision: D9983992
fbshipit-source-id: 16821628a99a5683dc39cbb345ddab56e9d8721c
diff --git a/caffe2/contrib/prof/prof_dag_stats_op.cc b/caffe2/contrib/prof/prof_dag_stats_op.cc
index 70f4c73..684f764 100644
--- a/caffe2/contrib/prof/prof_dag_stats_op.cc
+++ b/caffe2/contrib/prof/prof_dag_stats_op.cc
@@ -18,9 +18,9 @@
"op will be calculated separately")
.Arg(
"partial_net_name",
- "(string) default to empty; describes the partial name of the ProfDAGNet")
+ "(string) default to empty; describes the partial name of the net")
.Arg(
"net_name",
- "(string) default to empty; describes the name of the ProfDAGNet");
+ "(string) default to empty; describes the name of the net");
} // namespace
} // namespace caffe2
diff --git a/caffe2/contrib/prof/prof_dag_stats_op.h b/caffe2/contrib/prof/prof_dag_stats_op.h
index 1b90781..578ec44 100644
--- a/caffe2/contrib/prof/prof_dag_stats_op.h
+++ b/caffe2/contrib/prof/prof_dag_stats_op.h
@@ -3,6 +3,7 @@
#include "caffe2/contrib/prof/prof_dag_net.h"
#include "caffe2/core/context.h"
+#include "caffe2/core/net_async_base.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
@@ -53,14 +54,16 @@
" as part of its name");
}
- auto prof_dag_net = dynamic_cast_if_rtti<ProfDAGNet*>(net);
- CAFFE_ENFORCE(prof_dag_net);
-
ProfDAGProtos stats;
- if (per_op_) {
- stats = prof_dag_net->GetPerOperatorCost();
+ auto async_net = dynamic_cast_if_rtti<AsyncNetBase*>(net);
+ if (async_net) {
+ LOG(INFO) << "Using AsyncNetBase to collect stats";
+ stats = getProtos(async_net);
} else {
- stats = prof_dag_net->GetOperatorStats();
+ auto prof_dag_net = dynamic_cast_if_rtti<ProfDAGNet*>(net);
+ CAFFE_ENFORCE(prof_dag_net);
+ LOG(INFO) << "Using ProfDAGNet to collect stats";
+ stats = getProtos(prof_dag_net);
}
// Write protobuf message to the output blob
@@ -72,6 +75,17 @@
return true;
}
+ template <typename Net>
+ ProfDAGProtos getProtos(Net* net) {
+ ProfDAGProtos stats;
+ if (per_op_) {
+ stats = net->GetPerOperatorCost();
+ } else {
+ stats = net->GetOperatorStats();
+ }
+ return stats;
+ }
+
protected:
std::string net_name_;
std::string partial_net_name_;
diff --git a/caffe2/core/net.cc b/caffe2/core/net.cc
index 37f38ce..f51d336 100644
--- a/caffe2/core/net.cc
+++ b/caffe2/core/net.cc
@@ -107,6 +107,7 @@
const std::unordered_map<std::string, std::string>& defaultOverrides() {
static const std::unordered_map<std::string, std::string> overrides = {
{"dag", "async_scheduling"},
+ {"prof_dag", "async_scheduling"},
{"async_dag", "async_scheduling"},
{"async_polling", "async_scheduling"},
{"async_simple", "simple"},
diff --git a/caffe2/core/net_async_base.cc b/caffe2/core/net_async_base.cc
index ba5ca01..c2728c7af 100644
--- a/caffe2/core/net_async_base.cc
+++ b/caffe2/core/net_async_base.cc
@@ -68,7 +68,9 @@
AsyncNetBase::AsyncNetBase(
const std::shared_ptr<const NetDef>& net_def,
Workspace* ws)
- : NetBase(net_def, ws) {
+ : NetBase(net_def, ws), counters_(net_def) {
+ computeExecutionModeFlags();
+
operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
helper_ = caffe2::make_unique<AsyncNetExecutorHelper>(this);
operators_.reserve(operator_nodes_.size());
@@ -93,12 +95,15 @@
for (const auto& chain : chains_) {
const auto& last_op = operators_[chain.back()];
events_.push_back(&last_op->event());
- for (const auto& op_id : chain) {
- if (op_id == chain.back() || op_id == chain.front()) {
- continue;
+ // keep events for inner chain ops in case of profiling
+ if (!report_stats_) {
+ for (const auto& op_id : chain) {
+ if (op_id == chain.back() || op_id == chain.front()) {
+ continue;
+ }
+ const auto& op = operators_[op_id];
+ op->DisableEvent();
}
- const auto& op = operators_[op_id];
- op->DisableEvent();
}
}
@@ -108,8 +113,6 @@
if (tracer_) {
LOG(INFO) << "Tracing net: " << net_def->name();
}
-
- computeExecutionModeFlags();
}
bool AsyncNetBase::handleRunError() {
@@ -294,14 +297,28 @@
return chains_[task_id].size();
}
+int AsyncNetBase::firstTaskOpId(int task_id) const {
+ return chains_[task_id].front();
+}
+
+int AsyncNetBase::lastTaskOpId(int task_id) const {
+ return chains_[task_id].back();
+}
+
const OperatorBase* AsyncNetBase::firstTaskOp(int task_id) const {
- auto op_id = chains_[task_id].front();
- return operator_nodes_[op_id].operator_.get();
+ return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
}
const OperatorBase* AsyncNetBase::lastTaskOp(int task_id) const {
- auto op_id = chains_[task_id].back();
- return operator_nodes_[op_id].operator_.get();
+ return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
+}
+
+OperatorBase* AsyncNetBase::firstTaskOp(int task_id) {
+ return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
+}
+
+OperatorBase* AsyncNetBase::lastTaskOp(int task_id) {
+ return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
}
void AsyncNetBase::asyncWait(
@@ -364,14 +381,24 @@
}
for (auto& op_id : chains_[task_id]) {
op = operators_[op_id];
- TRACE_EVENT(
- tracing::TRACE_OP,
- op_id,
- tracing::TRACE_TASK,
- task_id,
- tracing::TRACE_STREAM,
- stream_id);
- bool success = op->RunAsync(stream_id);
+ bool success = false;
+ if (!report_stats_) {
+ TRACE_EVENT(
+ tracing::TRACE_OP,
+ op_id,
+ tracing::TRACE_TASK,
+ task_id,
+ tracing::TRACE_STREAM,
+ stream_id);
+ success = op->RunAsync(stream_id);
+ } else {
+ counters_.AddPerOpStartTime(op_id);
+ success = op->RunAsync(stream_id);
+ if (success && op->device_option().device_type() != PROTO_CPU) {
+ op->Finish();
+ }
+ counters_.AddPerOpEndTime(op_id);
+ }
if (!success) {
auto err_msg = "Failed to execute an op: " +
(op->has_debug_def() ? op->type() : " unknown");
@@ -425,7 +452,19 @@
}
}
-AsyncNetBase::~AsyncNetBase() {}
+ProfDAGProtos AsyncNetBase::GetOperatorStats() const {
+ return counters_.GetOperatorStats();
+}
+
+ProfDAGProtos AsyncNetBase::GetPerOperatorCost() const {
+ return counters_.GetPerOperatorCost();
+}
+
+AsyncNetBase::~AsyncNetBase() {
+ if (report_stats_) {
+ counters_.PrintStats();
+ }
+}
C10_DEFINE_SHARED_REGISTRY(
ThreadPoolRegistry,
@@ -459,6 +498,7 @@
use_single_pool_ = true;
use_per_net_pools_ = true;
is_blocking_ = true;
+ report_stats_ = (net_type == kProfDag);
} else if (net_type == kAsyncDag) {
streams_per_gpu_ = 1;
finish_chain_ = false;
@@ -467,6 +507,7 @@
use_single_pool_ = true;
use_per_net_pools_ = true;
is_blocking_ = true;
+ report_stats_ = false;
} else {
streams_per_gpu_ = c10::FLAGS_caffe2_streams_per_gpu;
finish_chain_ = c10::FLAGS_caffe2_net_async_finish_chain;
@@ -475,6 +516,16 @@
use_single_pool_ = c10::FLAGS_caffe2_net_async_use_single_pool;
use_per_net_pools_ = c10::FLAGS_caffe2_net_async_use_per_net_pools;
is_blocking_ = false;
+ report_stats_ = false;
+ }
+
+ for (int arg_idx = 0; arg_idx < net_def_->arg_size(); ++arg_idx) {
+ auto& arg = net_def_->arg(arg_idx);
+ if (arg.has_name() && arg.name() == "enable_profiling") {
+ CAFFE_ENFORCE(arg.has_i(), "enable_profiling should be an int");
+ report_stats_ = arg.i() == 1;
+ break;
+ }
}
}
diff --git a/caffe2/core/net_async_base.h b/caffe2/core/net_async_base.h
index f8d61f7..8dccaf4 100644
--- a/caffe2/core/net_async_base.h
+++ b/caffe2/core/net_async_base.h
@@ -6,10 +6,12 @@
#include "caffe2/core/net.h"
#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_dag_utils.h"
+#include "caffe2/core/prof_dag_counters.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/timer.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
+#include "caffe2/proto/prof_dag.pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/thread_pool.h"
@@ -50,6 +52,9 @@
return execution_chains_;
}
+ ProfDAGProtos GetOperatorStats() const;
+ ProfDAGProtos GetPerOperatorCost() const;
+
protected:
bool canSchedule(
int chain_id,
@@ -66,8 +71,13 @@
int getParentCount(int child_id);
bool testAndSetScheduled(int task_id);
int numOps(int task_id) const;
+
+ int firstTaskOpId(int task_id) const;
+ int lastTaskOpId(int task_id) const;
const OperatorBase* firstTaskOp(int task_id) const;
const OperatorBase* lastTaskOp(int task_id) const;
+ OperatorBase* firstTaskOp(int task_id);
+ OperatorBase* lastTaskOp(int task_id);
void asyncWait(
int task_id,
@@ -126,6 +136,9 @@
bool use_single_pool_;
bool use_per_net_pools_;
bool is_blocking_;
+ bool report_stats_;
+
+ ProfDAGCounters counters_;
C10_DISABLE_COPY_AND_ASSIGN(AsyncNetBase);
diff --git a/caffe2/core/net_async_scheduling.cc b/caffe2/core/net_async_scheduling.cc
index 369a8e7..1db51f2 100644
--- a/caffe2/core/net_async_scheduling.cc
+++ b/caffe2/core/net_async_scheduling.cc
@@ -61,6 +61,16 @@
}
}
+ if (report_stats_) {
+ auto last_op_id = lastTaskOpId(task_id);
+ auto* last_op = lastTaskOp(task_id);
+ if (last_op->device_option().device_type() == PROTO_CPU &&
+ last_op->HasAsyncPart()) {
+ last_op->event().SetCallback(
+ [this, last_op_id] { counters_.AddPerOpAsyncEndTime(last_op_id); });
+ }
+ }
+
for (auto child_id : children(task_id)) {
int parent_count = updateParentCount(child_id);
if (parent_count == 0) {
@@ -209,6 +219,9 @@
std::unique_lock<std::mutex> lock(running_mutex_);
// wait for scheduled ops and make sure all events are marked as finished
finalizeEvents();
+ if (report_stats_) {
+ counters_.ReportRunEnd();
+ }
// notify observers and waiters
StopAllObservers();
running_ = false;
@@ -228,7 +241,11 @@
StartAllObservers();
tracing::startIter(tracer_);
+ if (report_stats_) {
+ counters_.ReportRunStart();
+ }
}
+
for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
if (parents(task_id).empty()) {
schedule(task_id);
diff --git a/caffe2/core/prof_dag_counters.cc b/caffe2/core/prof_dag_counters.cc
new file mode 100644
index 0000000..f256880
--- /dev/null
+++ b/caffe2/core/prof_dag_counters.cc
@@ -0,0 +1,156 @@
+#include "caffe2/core/prof_dag_counters.h"
+
+#include <ostream>
+#include <sstream>
+
+namespace caffe2 {
+
+ProfDAGCounters::ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def)
+ : net_name_(net_def->name()), num_runs_(0) {
+ op_types_.reserve(net_def->op_size());
+ for (auto op_id = 0; op_id < net_def->op_size(); ++op_id) {
+ op_types_.push_back(net_def->op(op_id).type());
+ }
+ time_per_op_total_.resize(op_types_.size());
+}
+
+void ProfDAGCounters::ReportRunStart() {
+ num_runs_ += 1;
+ timer_.Start();
+
+ op_start_times_run_.clear();
+ op_start_times_run_.resize(op_types_.size(), -1.0);
+ op_end_times_run_.clear();
+ op_end_times_run_.resize(op_types_.size(), -1.0);
+ op_async_end_times_run_.clear();
+ op_async_end_times_run_.resize(op_types_.size(), -1.0);
+}
+
+void ProfDAGCounters::AddPerOpStartTime(size_t op_id) {
+ if (num_runs_ <= 1) {
+ return;
+ }
+
+ CAFFE_ENFORCE(op_id >= 0 && op_id < op_start_times_run_.size());
+ op_start_times_run_[op_id] = timer_.MilliSeconds();
+}
+
+void ProfDAGCounters::AddPerOpEndTime(size_t op_id) {
+ if (num_runs_ <= 1) {
+ return;
+ }
+
+ CAFFE_ENFORCE(op_id >= 0 && op_id < op_end_times_run_.size());
+ op_end_times_run_[op_id] = timer_.MilliSeconds();
+}
+
+void ProfDAGCounters::AddPerOpAsyncEndTime(size_t op_id) {
+ if (num_runs_ <= 1) {
+ return;
+ }
+
+ CAFFE_ENFORCE(op_id >= 0 && op_id < op_async_end_times_run_.size());
+ op_async_end_times_run_[op_id] = timer_.MilliSeconds();
+}
+
+void ProfDAGCounters::ReportRunEnd() {
+ if (num_runs_ <= 1) {
+ return;
+ }
+
+ auto runtime = timer_.MilliSeconds();
+ runtime_stats_ += ProfDAGStats(runtime);
+
+ CaffeMap<std::string, float> cum_per_type_time_run_;
+ CaffeMap<std::string, float> cum_per_type_invocations_run_;
+ for (auto op_id = 0; op_id < op_types_.size(); ++op_id) {
+ float op_time;
+ CAFFE_ENFORCE(op_start_times_run_[op_id] > 0);
+ if (op_async_end_times_run_[op_id] > 0) {
+ auto op_async_time =
+ op_async_end_times_run_[op_id] - op_start_times_run_[op_id];
+ CAFFE_ENFORCE_GE(op_async_time, 0.0);
+ op_time = op_async_time;
+ } else {
+ auto op_sync_time = op_end_times_run_[op_id] - op_start_times_run_[op_id];
+ CAFFE_ENFORCE_GE(op_sync_time, 0.0);
+ op_time = op_sync_time;
+ }
+
+ time_per_op_total_[op_id] += ProfDAGStats(op_time);
+
+ const string& op_type = op_types_[op_id];
+ cum_per_type_time_run_[op_type] += op_time;
+ cum_per_type_invocations_run_[op_type] += 1;
+ }
+
+ for (const auto& kv : cum_per_type_time_run_) {
+ time_per_op_type_total_[kv.first] += ProfDAGStats(kv.second);
+ times_per_run_per_type_total_[kv.first] +=
+ ProfDAGStats(cum_per_type_invocations_run_[kv.first]);
+ }
+}
+
+ProfDAGProto ProfDAGCounters::statsProto(
+ const std::string& name,
+ const ProfDAGStats& stats) const {
+ ProfDAGProto stats_proto;
+ const auto& moments = stats.computeMoments();
+ stats_proto.set_mean(moments.first);
+ stats_proto.set_stddev(moments.second);
+ stats_proto.set_name(name);
+ return stats_proto;
+}
+
+ProfDAGProtos ProfDAGCounters::GetOperatorStats() const {
+ CAFFE_ENFORCE_GT(num_runs_, 1, "Insufficient number of runs");
+ ProfDAGProtos prof_dag_protos;
+ for (auto& item : time_per_op_type_total_) {
+ auto buf = prof_dag_protos.add_stats();
+ buf->CopyFrom(statsProto(item.first, item.second));
+ }
+ return prof_dag_protos;
+}
+
+ProfDAGProtos ProfDAGCounters::GetPerOperatorCost() const {
+ CAFFE_ENFORCE_GT(num_runs_, 1, "Insufficient number of runs");
+ ProfDAGProtos prof_dag_protos;
+ for (int op_id = 0; op_id < op_types_.size(); op_id++) {
+ const string& op_type = op_types_[op_id];
+ auto buf = prof_dag_protos.add_stats();
+ std::string op_output_name =
+ net_name_ + "___" + to_string(op_id) + "___" + op_type;
+ buf->CopyFrom(statsProto(op_output_name, time_per_op_total_[op_id]));
+ }
+ return prof_dag_protos;
+}
+
+void ProfDAGCounters::PrintStats() {
+ if (num_runs_ <= 1) {
+ LOG(INFO) << "Insufficient number of runs";
+ return;
+ }
+
+ std::ostringstream debug_out;
+ debug_out << "Measured operators over " << num_runs_ << " net runs ("
+ << net_name_ << "), #ops: " << op_types_.size() << std::endl;
+
+ debug_out << "Mean time in operator type per run (stddev):" << std::endl;
+ for (const auto& item : time_per_op_type_total_) {
+ const auto& moments = item.second.computeMoments();
+ const auto& times_moments =
+ times_per_run_per_type_total_[item.first].computeMoments();
+ debug_out << std::setw(10) << std::setfill(' ') << moments.first
+ << " ms/run (" << std::setw(10) << std::setfill(' ')
+ << moments.second << " ms/run) "
+ << " Op count per run: " << times_moments.first << " "
+ << item.first << std::endl;
+ }
+ const auto& runtime_moments = runtime_stats_.computeMoments();
+ debug_out << net_name_ << " runtime: " << runtime_moments.first << " ms ("
+ << runtime_moments.second << " ms)" << std::endl;
+
+ LOG(INFO) << debug_out.str();
+}
+
+} // namespace caffe2
diff --git a/caffe2/core/prof_dag_counters.h b/caffe2/core/prof_dag_counters.h
new file mode 100644
index 0000000..49d7305
--- /dev/null
+++ b/caffe2/core/prof_dag_counters.h
@@ -0,0 +1,104 @@
+#ifndef PROF_DAG_COUNTERS_H
+#define PROF_DAG_COUNTERS_H
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/timer.h"
+#include "caffe2/proto/caffe2_pb.h"
+#include "caffe2/proto/prof_dag.pb.h"
+
+#include <unordered_map>
+
+namespace caffe2 {
+
+class ProfDAGStats {
+ public:
+ ProfDAGStats() : sum_(0.0), sqrsum_(0.0), cnt_(0) {}
+ explicit ProfDAGStats(float time_ms)
+ : sum_(time_ms), sqrsum_(time_ms * time_ms), cnt_(1) {}
+
+ ProfDAGStats& operator+=(const ProfDAGStats& rhs) {
+ sum_ += rhs.sum_;
+ sqrsum_ += rhs.sqrsum_;
+ cnt_ += rhs.cnt_;
+ return *this;
+ }
+
+ std::pair<float, float> computeMoments() const {
+ CAFFE_ENFORCE_GT(cnt_, 0);
+ float mean = sum_ / cnt_;
+ float stddev = std::sqrt(sqrsum_ / cnt_ - mean * mean);
+ return {mean, stddev};
+ }
+
+ float sum() const {
+ return sum_;
+ }
+
+ float sqrsum() const {
+ return sqrsum_;
+ }
+
+ size_t cnt() const {
+ return cnt_;
+ }
+
+ private:
+ float sum_;
+ float sqrsum_;
+ size_t cnt_;
+};
+
+/**
+ * A simple wrapper around prof_dag's counters
+ */
+class ProfDAGCounters {
+ public:
+ explicit ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def);
+
+ // Collects the execution time per each operator type
+ ProfDAGProtos GetOperatorStats() const;
+
+ // Collects the execution time of each operator, the output is
+ // formatted as a map: (netName__opIndex__opType, cost)
+ ProfDAGProtos GetPerOperatorCost() const;
+
+ // ReportRunStart/End are called at the beginning and at the end of
+ // each net's run
+ void ReportRunStart();
+ void ReportRunEnd();
+
+ void AddPerOpStartTime(size_t op_id);
+ void AddPerOpEndTime(size_t op_id);
+ void AddPerOpAsyncEndTime(size_t op_id);
+
+ void PrintStats();
+
+ private:
+ ProfDAGProto statsProto(const std::string& name, const ProfDAGStats& stats)
+ const;
+
+ std::vector<std::string> op_types_;
+
+ // Cumulative stats per operator instance of the net
+ std::vector<ProfDAGStats> time_per_op_total_;
+
+ // Cumulative stats per unique operator type
+ CaffeMap<std::string, ProfDAGStats> time_per_op_type_total_;
+
+ CaffeMap<std::string, ProfDAGStats> times_per_run_per_type_total_;
+
+ std::string net_name_;
+
+ int num_runs_;
+ Timer timer_;
+ ProfDAGStats runtime_stats_;
+
+ std::vector<float> op_start_times_run_;
+ std::vector<float> op_end_times_run_;
+ std::vector<float> op_async_end_times_run_;
+};
+
+} // namespace caffe2
+
+#endif