Add CUDNN to the gpu devices' default preferred engines

Summary: CUDNN should almost always be faster than the default implementation

Reviewed By: Yangqing

Differential Revision: D5633240

fbshipit-source-id: 99c45c04bf6a3c19f3f7eb27be1bb89344bc03d4
diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc
index fe52897..982078c 100644
--- a/caffe2/core/operator.cc
+++ b/caffe2/core/operator.cc
@@ -39,8 +39,17 @@
 }
 
 namespace {
-static PerOpEnginePrefType g_per_op_engine_pref{};
-static GlobalEnginePrefType g_global_engine_pref{};
+
+PerOpEnginePrefType& g_per_op_engine_pref() {  // Accessor for the per-op engine preference table.
+  static auto* g_per_op_engine_pref_ = new PerOpEnginePrefType();  // Allocated on first call; not deleted in this function.
+  return *g_per_op_engine_pref_;  // Mutable reference: callers read, index into, and assign through it.
+}
+
+GlobalEnginePrefType& g_global_engine_pref() {  // Accessor for the engine preference table keyed by device type.
+  static auto* g_global_engine_pref_ =  // Allocated on first call; not deleted in this function.
+      new GlobalEnginePrefType{{DeviceType::CUDA, {"CUDNN"}}};  // Initial content: CUDA -> {"CUDNN"}.
+  return *g_global_engine_pref_;  // Mutable reference: callers read, index into, and assign through it.
+}
 
 unique_ptr<OperatorBase> TryCreateOperator(
     const string& key, const OperatorDef& operator_def, Workspace* ws) {
@@ -91,14 +100,15 @@
     const auto op_def_engines = split(',', operator_def.engine());
     engines.insert(engines.end(), op_def_engines.begin(), op_def_engines.end());
   }
-  if (g_per_op_engine_pref.count(device_type) &&
-      g_per_op_engine_pref[device_type].count(op_type)) {
-    const auto& preferred_engines = g_per_op_engine_pref[device_type][op_type];
+  if (g_per_op_engine_pref().count(device_type) &&
+      g_per_op_engine_pref()[device_type].count(op_type)) {
+    const auto& preferred_engines =
+        g_per_op_engine_pref()[device_type][op_type];
     engines.insert(
         engines.end(), preferred_engines.begin(), preferred_engines.end());
   }
-  if (g_global_engine_pref.count(device_type)) {
-    const auto& preferred_engines = g_global_engine_pref[device_type];
+  if (g_global_engine_pref().count(device_type)) {
+    const auto& preferred_engines = g_global_engine_pref()[device_type];
     engines.insert(
         engines.end(), preferred_engines.begin(), preferred_engines.end());
   }
@@ -156,7 +166,7 @@
           " registry.");
     }
   }
-  g_per_op_engine_pref = per_op_engine_pref;
+  g_per_op_engine_pref() = per_op_engine_pref;
 }
 
 void SetGlobalEnginePref(const GlobalEnginePrefType& global_engine_pref) {
@@ -168,7 +178,7 @@
         device_type,
         " not registered.");
   }
-  g_global_engine_pref = global_engine_pref;
+  g_global_engine_pref() = global_engine_pref;
 }
 
 void SetEnginePref(
@@ -195,7 +205,7 @@
         " not registered in ",
         device_type,
         " registry.");
-    g_per_op_engine_pref[device_type][op_type] = device_pref_pair.second;
+    g_per_op_engine_pref()[device_type][op_type] = device_pref_pair.second;
   }
 }
 
diff --git a/caffe2/core/operator_gpu_test.cc b/caffe2/core/operator_gpu_test.cc
new file mode 100644
index 0000000..a3da21e
--- /dev/null
+++ b/caffe2/core/operator_gpu_test.cc
@@ -0,0 +1,60 @@
+#include <string>
+
+#include <gtest/gtest.h>
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+class JustTest : public OperatorBase {  // Test-only operator base; type() reports which implementation was built.
+ public:
+  using OperatorBase::OperatorBase;  // Inherit the base-class constructors.
+  bool Run(int /* unused */ /*stream_id*/) override {  // Does no work; always returns true.
+    return true;
+  }
+  virtual std::string type() {  // Tag string; the derived test operators override it.
+    return "BASE";
+  }
+};
+
+class JustTestCUDA : public JustTest {  // Variant registered through REGISTER_CUDA_OPERATOR in this file.
+ public:
+  using JustTest::JustTest;  // Inherit the base-class constructors.
+  bool Run(int /* unused */ /*stream_id*/) override {  // Does no work; always returns true.
+    return true;
+  }
+  std::string type() override {  // Tag identifying this variant.
+    return "CUDA";
+  }
+};
+
+class JustTestCUDNN : public JustTest {  // Variant registered through REGISTER_CUDNN_OPERATOR in this file.
+ public:
+  using JustTest::JustTest;  // Inherit the base-class constructors.
+  bool Run(int /* unused */ /*stream_id*/) override {  // Does no work; always returns true.
+    return true;
+  }
+  std::string type() override {  // Tag the test below expects to observe.
+    return "CUDNN";
+  }
+};
+
+OPERATOR_SCHEMA(JustTest).NumInputs(0, 1).NumOutputs(0, 1);  // NOTE(review): presumably (min, max) counts - confirm against the schema API.
+REGISTER_CUDA_OPERATOR(JustTest, JustTestCUDA);  // Op type "JustTest" -> JustTestCUDA via the CUDA registration macro.
+REGISTER_CUDNN_OPERATOR(JustTest, JustTestCUDNN);  // Same op type -> JustTestCUDNN via the CUDNN registration macro.
+
+TEST(EnginePrefTest, GPUDeviceDefaultPreferredEngines) {  // A CUDA "JustTest" op with no engine set should resolve to the CUDNN variant.
+  OperatorDef op_def;
+  Workspace ws;
+  op_def.mutable_device_option()->set_device_type(CUDA);
+  op_def.set_type("JustTest");  // No engine is set on the def itself.
+
+  {
+    const auto op = CreateOperator(op_def, &ws);
+    ASSERT_NE(nullptr, op.get());  // Fatal on purpose: the check below dereferences op.
+    // CUDNN should be taken as it's in the default global preferred engines
+    // list
+    EXPECT_EQ(static_cast<JustTest*>(op.get())->type(), "CUDNN");
+  }
+}
+
+} // namespace caffe2