add header for AsyncDAGNet

Summary: This diff adds a header file, net_gpu.h, for net_gpu.cc. The header declares the AsyncDAGNet class (and the internal::Event helper it uses) so that other classes can derive from AsyncDAGNet.

Reviewed By: ajtulloch

Differential Revision: D4230046

fbshipit-source-id: 379c3ff7ebb7aeeb4294f39e6f5d1ecad48b92f0
diff --git a/caffe2/core/net_gpu.cc b/caffe2/core/net_gpu.cc
index fa3c6e3..4093359 100644
--- a/caffe2/core/net_gpu.cc
+++ b/caffe2/core/net_gpu.cc
@@ -1,6 +1,5 @@
-#include "caffe2/core/net.h"
+#include "caffe2/core/net_gpu.h"
 
-#include "caffe2/core/context_gpu.h"
 #include "caffe2/core/flags.h"
 
 #include "caffe2/core/operator.h"
@@ -65,34 +64,9 @@
 
 #endif // ifdef CAFFE2_USE_NVTX
 
-struct Stream;
+} // namespace
 
-struct Event {
- public:
-  explicit Event(const DeviceOption& device_option) {
-    if (device_option.device_type() == CUDA) {
-      gpu_id_ = device_option.has_cuda_gpu_id() ? device_option.cuda_gpu_id()
-                                                : GetDefaultGPUID();
-      DeviceGuard g(gpu_id_);
-      CUDA_CHECK(cudaEventCreateWithFlags(
-          &event_, cudaEventDefault | cudaEventDisableTiming));
-    }
-  }
-
-  ~Event() {
-    if (event_) {
-      CUDA_CHECK(cudaEventDestroy(event_));
-    }
-  }
-
-  void record(const Stream& stream);
-
-  int gpu_id_{-1};
-  cudaEvent_t event_{nullptr};
-  bool outstanding_{false};
-  bool neverRecorded_{true};
-  DISABLE_COPY_AND_ASSIGN(Event);
-};
+namespace internal {
 
 struct Stream {
   explicit Stream(const DeviceOption& device_option) {
@@ -130,6 +104,16 @@
   DISABLE_COPY_AND_ASSIGN(Stream);
 };
 
+Event::Event(const DeviceOption& device_option) {
+  if (device_option.device_type() == CUDA) {
+    gpu_id_ = device_option.has_cuda_gpu_id() ? device_option.cuda_gpu_id()
+                                              : GetDefaultGPUID();
+    DeviceGuard g(gpu_id_);
+    CUDA_CHECK(cudaEventCreateWithFlags(
+        &event_, cudaEventDefault | cudaEventDisableTiming));
+  }
+}
+
 void Event::record(const Stream& stream) {
   if (outstanding_) {
     // TODO - should we do this?
@@ -155,102 +139,91 @@
   CUDA_CHECK(cudaEventRecord(event_, stream.stream_));
   outstanding_ = true;
 }
+
+} // namespace internal
+
+AsyncDAGNet::AsyncDAGNet(const NetDef& net_def, Workspace* ws)
+    : DAGNetBase(net_def, ws) {
+  VLOG(1) << "Constructing Async DAG Net " << net_def.name();
+  eventRecorded_.resize(net_def.op_size());
+  events_.reserve(net_def.op_size());
+  for (int idx = 0; idx < net_def.op_size(); ++idx) {
+    const OperatorDef& op_def = net_def.op(idx);
+    if (!op_def.has_device_option() && net_def.has_device_option()) {
+      OperatorDef temp_def(op_def);
+      temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
+      events_.emplace_back(new internal::Event(temp_def.device_option()));
+    } else {
+      events_.emplace_back(new internal::Event(op_def.device_option()));
+    }
+  }
 }
 
-// Run an event-driven graph - before each operator chain, wait on
-// each parent operator for the chain source (Stream::wait), then
-// execute each operator (implicitly on the same stream).
-class AsyncDAGNet : public DAGNetBase {
- public:
-  AsyncDAGNet(const NetDef& net_def, Workspace* ws) : DAGNetBase(net_def, ws) {
-    VLOG(1) << "Constructing Async DAG Net " << net_def.name();
-    eventRecorded_.resize(net_def.op_size());
-    events_.reserve(net_def.op_size());
-    for (int idx = 0; idx < net_def.op_size(); ++idx) {
-      const OperatorDef& op_def = net_def.op(idx);
-      if (!op_def.has_device_option() && net_def.has_device_option()) {
-        OperatorDef temp_def(op_def);
-        temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
-        events_.emplace_back(new Event(temp_def.device_option()));
-      } else {
-        events_.emplace_back(new Event(op_def.device_option()));
-      }
-    }
+bool AsyncDAGNet::RunAt(const std::vector<int>& chain) {
+  CAFFE_ENFORCE(!chain.empty(), "Chain should not be empty.");
+  const auto source_idx = chain.front();
+  internal::Stream stream{
+      operator_nodes_[source_idx].operator_->def().device_option()};
+  const auto& parents = operator_nodes_[source_idx].parents_;
+  // Help ensure that our chaining is correct by verifying at least
+  // one parent recorded an event.
+  CAFFE_ENFORCE(
+      parents.empty() || std::any_of(
+                             parents.begin(),
+                             parents.end(),
+                             [this](int p) { return eventRecorded_[p]; }),
+      "None of the parent is recorded for an event.");
+
+  for (auto source_parent_idx : operator_nodes_[source_idx].parents_) {
+    ProfiledRange r(
+        operator_nodes_[source_parent_idx].operator_->def(), kWaitColor);
+    stream.wait(events_[source_parent_idx].get());
   }
 
-  bool RunAt(const std::vector<int>& chain) override {
-    CAFFE_ENFORCE(!chain.empty(), "Chain should not be empty.");
-    const auto source_idx = chain.front();
-    Stream stream{operator_nodes_[source_idx].operator_->def().device_option()};
-    const auto& parents = operator_nodes_[source_idx].parents_;
-    // Help ensure that our chaining is correct by verifying at least
-    // one parent recorded an event.
-    CAFFE_ENFORCE(
-        parents.empty() || std::any_of(
-                               parents.begin(),
-                               parents.end(),
-                               [this](int p) { return eventRecorded_[p]; }),
-        "None of the parent is recorded for an event.");
-
-    for (auto source_parent_idx : operator_nodes_[source_idx].parents_) {
-      ProfiledRange r(
-          operator_nodes_[source_parent_idx].operator_->def(), kWaitColor);
-      stream.wait(events_[source_parent_idx].get());
-    }
-
-    // We've waited on all our parent indices.
-    bool success = true;
-    for (auto idx : chain) {
-      ProfiledRange r(operator_nodes_[idx].operator_->def(), kRunColor);
-      success &= operator_nodes_[idx].operator_->RunAsync();
-    }
-
-    // Record an event for the sink of the chain.
-    const auto& sink_idx = chain.back();
-    {
-      ProfiledRange r(operator_nodes_[sink_idx].operator_->def(), kRecordColor);
-      events_[sink_idx]->record(stream);
-    }
-    CAFFE_ENFORCE(
-        !eventRecorded_[sink_idx],
-        "An event for ",
-        sink_idx,
-        " should not be recorded.");
-    eventRecorded_[sink_idx] = 1;
-    return success;
+  // We've waited on all our parent indices.
+  bool success = true;
+  for (auto idx : chain) {
+    ProfiledRange r(operator_nodes_[idx].operator_->def(), kRunColor);
+    success &= operator_nodes_[idx].operator_->RunAsync();
   }
 
-  bool Run() override {
-    // Reset the event tracking at each iteration
-    eventRecorded_.assign(eventRecorded_.size(), 0);
-
-    const auto result = DAGNetBase::Run();
-
-    // Synchronize execution of the network with respect to the host.
-    DeviceOption device_option;
-    device_option.set_device_type(CPU);
-    Stream stream{device_option};
-
-    // Potential optimization: we can pre-compute outstanding events.
-    for (auto i = 0; i < events_.size(); ++i) {
-      auto& event = events_[i];
-      if (event->outstanding_) {
-        VLOG(2) << "Synchronizing host on outstanding event";
-        ProfiledRange r(operator_nodes_[i].operator_->def(), kWaitColor);
-        stream.wait(event.get());
-      }
-    }
-    return result;
+  // Record an event for the sink of the chain.
+  const auto& sink_idx = chain.back();
+  {
+    ProfiledRange r(operator_nodes_[sink_idx].operator_->def(), kRecordColor);
+    events_[sink_idx]->record(stream);
   }
+  CAFFE_ENFORCE(
+      !eventRecorded_[sink_idx],
+      "An event for ",
+      sink_idx,
+      " should not be recorded.");
+  eventRecorded_[sink_idx] = 1;
+  return success;
+}
 
- protected:
-  // Tracks whether a given op has had an event recorded in each
-  // RunAt() iteration.
+bool AsyncDAGNet::Run() {
+  // Reset the event tracking at each iteration
+  eventRecorded_.assign(eventRecorded_.size(), 0);
 
-  std::vector<int32_t> eventRecorded_;
-  std::vector<std::unique_ptr<Event>> events_;
-  DISABLE_COPY_AND_ASSIGN(AsyncDAGNet);
-};
+  const auto result = DAGNetBase::Run();
+
+  // Synchronize execution of the network with respect to the host.
+  DeviceOption device_option;
+  device_option.set_device_type(CPU);
+  internal::Stream stream{device_option};
+
+  // Potential optimization: we can pre-compute outstanding events.
+  for (auto i = 0; i < events_.size(); ++i) {
+    auto& event = events_[i];
+    if (event->outstanding_) {
+      VLOG(2) << "Synchronizing host on outstanding event";
+      ProfiledRange r(operator_nodes_[i].operator_->def(), kWaitColor);
+      stream.wait(event.get());
+    }
+  }
+  return result;
+}
 
 REGISTER_NET(async_dag, AsyncDAGNet);
 }
diff --git a/caffe2/core/net_gpu.h b/caffe2/core/net_gpu.h
new file mode 100644
index 0000000..ba372b7
--- /dev/null
+++ b/caffe2/core/net_gpu.h
@@ -0,0 +1,52 @@
+#ifndef CAFFE2_CORE_NET_GPU_H_
+#define CAFFE2_CORE_NET_GPU_H_
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/core/net.h"
+
+namespace caffe2 {
+
+namespace internal {
+
+struct Stream;
+
+struct Event {
+ public:
+  explicit Event(const DeviceOption& device_option);
+  ~Event() {
+    if (event_) {
+      CUDA_CHECK(cudaEventDestroy(event_));
+    }
+  }
+
+  void record(const Stream& stream);
+
+  int gpu_id_{-1};
+  cudaEvent_t event_{nullptr};
+  bool outstanding_{false};
+  bool neverRecorded_{true};
+  DISABLE_COPY_AND_ASSIGN(Event);
+};
+
+} // namespace internal
+
+// Run an event-driven graph - before each operator chain, wait on
+// each parent operator for the chain source (Stream::wait), then
+// execute each operator (implicitly on the same stream).
+class AsyncDAGNet : public DAGNetBase {
+ public:
+  AsyncDAGNet(const NetDef& net_def, Workspace* ws);
+  bool RunAt(const std::vector<int>& chain) override;
+  bool Run() override;
+
+ protected:
+  // Tracks, per op, whether RunAt() has recorded an event for it during
+  // the current Run() call (the flags are reset at the start of each Run()).
+  std::vector<int32_t> eventRecorded_;
+  std::vector<std::unique_ptr<internal::Event>> events_;
+  DISABLE_COPY_AND_ASSIGN(AsyncDAGNet);
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_CORE_NET_GPU_H_