Implement CUDA version of GRU operator

Summary: Add a CUDA implementation of the GRUUnit operator and its gradient
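
This change moves the reference implementations of detail::GRUUnit and
detail::GRUUnitGradient out of gru_unit_op.cc into gru_unit_op.h as templates
that take an extra Context* argument, so the CPU path and the new CUDA
specializations share one call site. A new gru_unit_op_gpu.cu adds forward and
backward CUDA kernels and registers GRUUnit and GRUUnitGradient for
CUDAContext. gru_test.py is extended to parametrize the reference and gradient
checks over CPU and GPU device options, keeping the timestep blob on CPU.

Per element, the forward pass computes
h_t = h_prev * sigmoid(u) + tanh(o) * (1 - sigmoid(u)),
where u and o are the update and output gate pre-activations taken from the
gates input; elements past their sequence length either copy h_prev or are
zeroed, depending on drop_states.

A minimal sketch of exercising the CUDA path directly (blob names, shapes, and
the (1, n, d) state layout are illustrative assumptions based on the test
helpers, not part of this diff); as in gru_test.py, the timestep input is fed
on CPU:

  import numpy as np
  from caffe2.proto import caffe2_pb2
  from caffe2.python import core, workspace

  n, d = 4, 8
  gpu = core.DeviceOption(caffe2_pb2.CUDA, 0)
  cpu = core.DeviceOption(caffe2_pb2.CPU)

  # Single-timestep inputs; the hidden state and gates live on the GPU.
  workspace.FeedBlob("hidden_t_prev",
                     np.random.randn(1, n, d).astype(np.float32),
                     device_option=gpu)
  workspace.FeedBlob("gates_t",
                     np.random.randn(1, n, 3 * d).astype(np.float32),
                     device_option=gpu)
  workspace.FeedBlob("seq_lengths",
                     np.random.randint(1, 10, size=(n,)).astype(np.int32),
                     device_option=gpu)
  # The timestep is read on the host, so it stays on CPU.
  workspace.FeedBlob("timestep",
                     np.array([0], dtype=np.int32),
                     device_option=cpu)

  op = core.CreateOperator(
      "GRUUnit",
      ["hidden_t_prev", "gates_t", "seq_lengths", "timestep"],
      ["hidden_t"],
      drop_states=False,
      device_option=gpu,
  )
  workspace.RunOperatorOnce(op)
  hidden_t = workspace.FetchBlob("hidden_t")  # expected shape (1, n, d)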

Reviewed By: jamesr66a

Differential Revision: D5571043

fbshipit-source-id: 332aa64fc8a9116cc33382f2b2907080e58c13b3
diff --git a/caffe2/operators/gru_unit_op.cc b/caffe2/operators/gru_unit_op.cc
index 5f17efa..2ee02cc 100644
--- a/caffe2/operators/gru_unit_op.cc
+++ b/caffe2/operators/gru_unit_op.cc
@@ -1,107 +1,6 @@
 #include "gru_unit_op.h"
 
 namespace caffe2 {
-namespace detail {
-
-template <typename T>
-inline T sigmoid(T x) {
-  return 1.0f / (1.0f + exp(-x));
-}
-
-template <typename T>
-inline T host_tanh(T x) {
-  return 2.0f * sigmoid(2.0f * x) - 1.0f;
-}
-
-template <typename T, typename Context>
-void GRUUnit(
-    int N,
-    int D,
-    int t,
-    const T* H_prev,
-    const T* X,
-    const int32_t* seqLengths,
-    bool drop_states,
-    T* H) {
-  for (int n = 0; n < N; ++n) {
-    const bool valid = t < seqLengths[n];
-
-    for (int d = 0; d < D; ++d) {
-      if (valid == false) {
-        if (drop_states) {
-          H[d] = 0;
-        } else {
-          H[d] = H_prev[d];
-        }
-      } else {
-        const T update = X[1 * D + d];
-        const T output = X[2 * D + d];
-        H[d] = H_prev[d] * sigmoid(update) +
-            host_tanh(output) * (1.0f - sigmoid(update));
-      }
-    }
-
-    H_prev += D;
-    X += 3 * D;
-    H += D;
-  }
-}
-
-template <typename T, typename Context>
-void GRUUnitGradient(
-    int N,
-    int D,
-    int t,
-    const T* H_prev,
-    const T* X,
-    const int32_t* seqLengths,
-    const T* H,
-    const T* H_diff,
-    bool drop_states,
-    T* H_prev_diff,
-    T* X_diff) {
-  for (int n = 0; n < N; ++n) {
-    const bool valid = t < seqLengths[n];
-
-    for (int d = 0; d < D; ++d) {
-      T* h_prev_diff = H_prev_diff + d;
-      T* reset_diff = X_diff + 0 * D + d;
-      T* update_diff = X_diff + 1 * D + d;
-      T* output_diff = X_diff + 2 * D + d;
-
-      if (!valid) {
-        if (drop_states) {
-          *h_prev_diff = 0;
-        } else {
-          *h_prev_diff = H_diff[d];
-        }
-        *reset_diff = 0;
-        *update_diff = 0;
-        *output_diff = 0;
-      } else {
-        // Calculate Gate Outputs
-        const T u = sigmoid(X[1 * D + d]);
-        const T o = host_tanh(X[2 * D + d]);
-
-        *h_prev_diff = H_diff[d] * u;
-        *reset_diff = 0; // 0 contribution to gradient from this operation
-        *update_diff = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u);
-        *output_diff = H_diff[d] * (1.0f - u) * (1.0f - o * o);
-      }
-    }
-
-    H_prev += D;
-    X += 3 * D;
-    H += D;
-    H_diff += D;
-    X_diff += 3 * D;
-    H_prev_diff += D;
-  }
-}
-
-} // namespace detail
-
-namespace {
 REGISTER_CPU_OPERATOR(GRUUnit, GRUUnitOp<float, CPUContext>);
 OPERATOR_SCHEMA(GRUUnit)
     .NumInputs(4)
@@ -147,5 +46,4 @@
   }
 };
 REGISTER_GRADIENT(GRUUnit, GetGRUUnitGradient);
-}
 } // namespace caffe2
diff --git a/caffe2/operators/gru_unit_op.h b/caffe2/operators/gru_unit_op.h
index efde370..34c25bc 100644
--- a/caffe2/operators/gru_unit_op.h
+++ b/caffe2/operators/gru_unit_op.h
@@ -8,6 +8,16 @@
 namespace caffe2 {
 namespace detail {
 
+template <typename T>
+inline T sigmoid(T x) {
+  return 1.0f / (1.0f + exp(-x));
+}
+
+template <typename T>
+inline T host_tanh(T x) {
+  return 2.0f * sigmoid(2.0f * x) - 1.0f;
+}
+
 template <typename T, typename Context>
 void GRUUnit(
     int N,
@@ -17,7 +27,32 @@
     const T* X,
     const int32_t* seqLengths,
     bool drop_states,
-    T* H);
+    T* H,
+    Context* /*context*/) {
+  for (int n = 0; n < N; ++n) {
+    const bool valid = t < seqLengths[n];
+
+    for (int d = 0; d < D; ++d) {
+      if (!valid) {
+        if (drop_states) {
+          H[d] = 0;
+        } else {
+          H[d] = H_prev[d];
+        }
+      } else {
+        const T update = X[1 * D + d];
+        const T output = X[2 * D + d];
+        T sigmoid_update = sigmoid(update);
+        H[d] = H_prev[d] * sigmoid_update +
+            host_tanh(output) * (1.0f - sigmoid_update);
+      }
+    }
+
+    H_prev += D;
+    X += 3 * D;
+    H += D;
+  }
+}
 
 template <typename T, typename Context>
 void GRUUnitGradient(
@@ -31,9 +66,48 @@
     const T* H_diff,
     bool drop_states,
     T* H_prev_diff,
-    T* X_diff);
+    T* X_diff,
+    Context* /*context*/) {
+  for (int n = 0; n < N; ++n) {
+    const bool valid = t < seqLengths[n];
 
-}; // namespace detail
+    for (int d = 0; d < D; ++d) {
+      T* h_prev_diff = H_prev_diff + d;
+      T* reset_diff = X_diff + 0 * D + d;
+      T* update_diff = X_diff + 1 * D + d;
+      T* output_diff = X_diff + 2 * D + d;
+
+      if (!valid) {
+        if (drop_states) {
+          *h_prev_diff = 0;
+        } else {
+          *h_prev_diff = H_diff[d];
+        }
+        *reset_diff = 0;
+        *update_diff = 0;
+        *output_diff = 0;
+      } else {
+        // Calculate Gate Outputs
+        const T u = sigmoid(X[1 * D + d]);
+        const T o = host_tanh(X[2 * D + d]);
+
+        *h_prev_diff = H_diff[d] * u;
+        *reset_diff = 0; // 0 contribution to gradient from this operation
+        *update_diff = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u);
+        *output_diff = H_diff[d] * (1.0f - u) * (1.0f - o * o);
+      }
+    }
+
+    H_prev += D;
+    X += 3 * D;
+    H += D;
+    H_diff += D;
+    X_diff += 3 * D;
+    H_prev_diff += D;
+  }
+}
+
+} // namespace detail
 
 template <typename T, typename Context>
 class GRUUnitOp : public Operator<Context> {
@@ -64,7 +138,7 @@
     auto* H = Output(HIDDEN_T)->template mutable_data<T>();
 
     detail::GRUUnit<T, Context>(
-        N, D, t, H_prev, X, seqLengths, drop_states_, H);
+        N, D, t, H_prev, X, seqLengths, drop_states_, H, &context_);
     return true;
   }
 
@@ -118,7 +192,8 @@
         H_diff,
         drop_states_,
         H_prev_diff,
-        X_diff);
+        X_diff,
+        &context_);
     return true;
   }
 
diff --git a/caffe2/operators/gru_unit_op_gpu.cu b/caffe2/operators/gru_unit_op_gpu.cu
new file mode 100644
index 0000000..9e99be1
--- /dev/null
+++ b/caffe2/operators/gru_unit_op_gpu.cu
@@ -0,0 +1,140 @@
+#include <algorithm>
+#include <cmath>
+#include <vector>
+#include "caffe2/core/context_gpu.h"
+#include "gru_unit_op.h"
+
+namespace caffe2 {
+
+namespace detail {
+
+template <typename Dtype>
+__device__ Dtype cuda_sigmoid(const Dtype x) {
+  return Dtype(1) / (Dtype(1) + exp(-x));
+}
+
+template <typename T>
+__global__ void GRUUnitKernel(
+    const int ND,
+    const int dim,
+    const int t,
+    const T* H_prev,
+    const T* X,
+    const int32_t* seqLengths,
+    bool drop_states,
+    T* H) {
+  // index is virtual thread ID in range [0, ND)
+  CUDA_1D_KERNEL_LOOP(index, ND) {
+    const int n = index / dim;
+    const int d = index % dim;
+    const bool valid = t < seqLengths[n];
+    if (!valid) {
+      H[index] = H_prev[index] * !drop_states;
+    } else {
+      const T* X_offset = X + 3 * dim * n;
+      const T update = X_offset[1 * dim + d];
+      const T output = X_offset[2 * dim + d];
+      T sigmoid_update = cuda_sigmoid(update);
+      H[index] = H_prev[index] * sigmoid_update +
+          tanh(output) * (1.0f - sigmoid_update);
+    }
+  }
+}
+
+template <typename T>
+__global__ void GRUUnitGradientKernel(
+    const int ND,
+    const int dim,
+    const int t,
+    const T* H_prev,
+    const T* X,
+    const int32_t* seqLengths,
+    const T* H,
+    const T* H_diff,
+    bool drop_states,
+    T* H_prev_diff,
+    T* X_diff) {
+  CUDA_1D_KERNEL_LOOP(index, ND) {
+    const int n = index / dim;
+    const bool valid = t < seqLengths[n];
+    const int d = index % dim;
+    const T* X_offset = X + 3 * dim * n;
+    T* h_prev_diff = H_prev_diff + index;
+    T* X_diff_offset = X_diff + 3 * dim * n;
+    T* reset_diff = X_diff_offset + 0 * dim + d;
+    T* update_diff = X_diff_offset + 1 * dim + d;
+    T* output_diff = X_diff_offset + 2 * dim + d;
+
+    if (!valid) {
+      *h_prev_diff = H_diff[index] * !drop_states;
+      *reset_diff = 0;
+      *update_diff = 0;
+      *output_diff = 0;
+    } else {
+      const T u = cuda_sigmoid(X_offset[1 * dim + d]);
+      const T o = tanh(X_offset[2 * dim + d]);
+
+      *h_prev_diff = H_diff[index] * u;
+      *reset_diff = 0; // 0 contribution to gradient from this operation
+      *update_diff =
+          (H_diff[index] * H_prev[index] - H_diff[index] * o) * u * (1.0f - u);
+      *output_diff = H_diff[index] * (1.0f - u) * (1.0f - o * o);
+    }
+  }
+}
+
+template <>
+void GRUUnit<float, CUDAContext>(
+    int N,
+    int D,
+    int t,
+    const float* H_prev,
+    const float* X,
+    const int32_t* seqLengths,
+    bool drop_states,
+    float* H,
+    CUDAContext* context) {
+  GRUUnitKernel<float>
+      <<<CAFFE_GET_BLOCKS(N * D),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context->cuda_stream()>>>(
+          N * D, D, t, H_prev, X, seqLengths, drop_states, H);
+}
+
+template <>
+void GRUUnitGradient<float, CUDAContext>(
+    int N,
+    int D,
+    int t,
+    const float* H_prev,
+    const float* X,
+    const int32_t* seqLengths,
+    const float* H,
+    const float* H_diff,
+    bool drop_states,
+    float* H_prev_diff,
+    float* X_diff,
+    CUDAContext* context) {
+  GRUUnitGradientKernel<float>
+      <<<CAFFE_GET_BLOCKS(N * D),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context->cuda_stream()>>>(
+          N * D,
+          D,
+          t,
+          H_prev,
+          X,
+          seqLengths,
+          H,
+          H_diff,
+          drop_states,
+          H_prev_diff,
+          X_diff);
+}
+} // namespace detail
+
+REGISTER_CUDA_OPERATOR(GRUUnit, GRUUnitOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(GRUUnitGradient, GRUUnitGradientOp<float, CUDAContext>);
+} // namespace caffe2
diff --git a/caffe2/python/operator_test/gru_test.py b/caffe2/python/operator_test/gru_test.py
index 3034500..5937f6a 100644
--- a/caffe2/python/operator_test/gru_test.py
+++ b/caffe2/python/operator_test/gru_test.py
@@ -3,11 +3,11 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
-from caffe2.python import workspace, scope, gru_cell
+from caffe2.python import workspace, core, scope, gru_cell
 from caffe2.python.model_helper import ModelHelper
 from caffe2.python.rnn.rnn_cell_test_util import sigmoid, tanh, _prepare_rnn
 import caffe2.python.hypothesis_test_util as hu
-
+from caffe2.proto import caffe2_pb2
 
 from functools import partial
 from hypothesis import given
@@ -150,7 +150,7 @@
     return dims_.flatmap(create_input)
 
 
-def _prepare_gru_unit_op(n, d, outputs_with_grads,
+def _prepare_gru_unit_op(gc, n, d, outputs_with_grads,
                          forward_only=False, drop_states=False,
                          two_d_initial_states=None):
     print("Dims: (n,d) = ({},{})".format(n, d))
@@ -173,11 +173,13 @@
             )
         workspace.FeedBlob(
             hidden_t_prev,
-            generate_input_state(n, d).astype(np.float32)
+            generate_input_state(n, d).astype(np.float32),
+            device_option=gc
         )
         workspace.FeedBlob(
             gates_t,
-            generate_input_state(n, 3 * d).astype(np.float32)
+            generate_input_state(n, 3 * d).astype(np.float32),
+            device_option=gc
         )
 
         hidden_t = model.net.GRUUnit(
@@ -198,12 +200,15 @@
         # and generate some reasonable seq. lengths
         workspace.FeedBlob(
             seq_lengths,
-            np.random.randint(1, 10, size=(n,)).astype(np.int32)
+            np.random.randint(1, 10, size=(n,)).astype(np.int32),
+            device_option=gc
         )
         workspace.FeedBlob(
             timestep,
-            np.random.randint(1, 10, size=(1,)).astype(np.int32)
+            np.random.randint(1, 10, size=(1,)).astype(np.int32),
+            device_option=core.DeviceOption(caffe2_pb2.CPU),
         )
+        print("Feed {}".format(timestep))
 
     return hidden_t, model.net
 
@@ -215,9 +220,10 @@
         input_tensor=gru_unit_op_input(),
         fwd_only=st.booleans(),
         drop_states=st.booleans(),
+        **hu.gcs
     )
     @ht_settings(max_examples=15)
-    def test_gru_unit_op(self, input_tensor, fwd_only, drop_states, **kwargs):
+    def test_gru_unit_op(self, input_tensor, fwd_only, drop_states, gc, dc):
         outputs_with_grads = [0]
         ref = gru_unit
         ref = partial(ref)
@@ -227,22 +233,26 @@
         d = d // 3
         ref = partial(ref, drop_states=drop_states)
 
-        net = _prepare_gru_unit_op(n, d,
-                                   outputs_with_grads=outputs_with_grads,
-                                   forward_only=fwd_only,
-                                   drop_states=drop_states)[1]
+        with core.DeviceScope(gc):
+            net = _prepare_gru_unit_op(gc, n, d,
+                                       outputs_with_grads=outputs_with_grads,
+                                       forward_only=fwd_only,
+                                       drop_states=drop_states)[1]
         # here we don't provide a real input for the net but just for one of
         # its ops (RecurrentNetworkOp). So have to hardcode this name
         workspace.FeedBlob("test_name_scope/external/recurrent/i2h",
-                           input_tensor)
+                           input_tensor,
+                           device_option=gc)
+        print(str(net.Proto()))
         op = net._net.op[-1]
         inputs = [workspace.FetchBlob(name) for name in op.input]
 
         self.assertReferenceChecks(
-            hu.cpu_do,
+            gc,
             op,
             inputs,
             ref,
+            input_device_options={op.input[3]: hu.cpu_do},
             outputs_to_check=[0],
         )
 
@@ -251,19 +261,21 @@
             for param in range(2):
                 print("Check param {}".format(param))
                 self.assertGradientChecks(
-                    device_option=hu.cpu_do,
+                    device_option=gc,
                     op=op,
                     inputs=inputs,
                     outputs_to_check=param,
                     outputs_with_grads=outputs_with_grads,
                     threshold=0.0001,
                     stepsize=0.005,
+                    input_device_options={op.input[3]: hu.cpu_do},
                 )
 
     @given(
         input_tensor=gru_input(),
         fwd_only=st.booleans(),
         drop_states=st.booleans(),
+        **hu.gcs
     )
     @ht_settings(max_examples=15)
     def test_gru_main(self, **kwargs):
@@ -273,32 +285,33 @@
                            **kwargs)
 
     def gru_base(self, create_rnn, ref, outputs_with_grads,
-                  input_tensor, fwd_only, drop_states):
+                  input_tensor, fwd_only, drop_states, gc, dc):
         print("GRU test parameters: ", locals())
-
         t, n, d = input_tensor.shape
         assert d % 3 == 0
         d = d // 3
         ref = partial(ref, drop_states=drop_states)
-
-        net = _prepare_rnn(t, n, d, create_rnn,
-                            outputs_with_grads=outputs_with_grads,
-                            memory_optim=False,
-                            forget_bias=0.0,
-                            forward_only=fwd_only,
-                            drop_states=drop_states)[1]
+        with core.DeviceScope(gc):
+            net = _prepare_rnn(t, n, d, create_rnn,
+                                outputs_with_grads=outputs_with_grads,
+                                memory_optim=False,
+                                forget_bias=0.0,
+                                forward_only=fwd_only,
+                                drop_states=drop_states)[1]
         # here we don't provide a real input for the net but just for one of
         # its ops (RecurrentNetworkOp). So have to hardcode this name
         workspace.FeedBlob("test_name_scope/external/recurrent/i2h",
-                           input_tensor)
+                           input_tensor,
+                           device_option=gc)
         op = net._net.op[-1]
         inputs = [workspace.FetchBlob(name) for name in op.input]
 
         self.assertReferenceChecks(
-            hu.cpu_do,
+            gc,
             op,
             inputs,
             ref,
+            input_device_options={"timestep": hu.cpu_do},
             outputs_to_check=list(range(2)),
         )
 
@@ -307,11 +320,12 @@
             for param in range(2):
                 print("Check param {}".format(param))
                 self.assertGradientChecks(
-                    device_option=hu.cpu_do,
+                    device_option=gc,
                     op=op,
                     inputs=inputs,
                     outputs_to_check=param,
                     outputs_with_grads=outputs_with_grads,
                     threshold=0.001,
                     stepsize=0.005,
+                    input_device_options={"timestep": hu.cpu_do},
                 )