fp16: RecurrentNetwork

Summary:
Add fp16 support to the RecurrentNetwork operators. The ops are now
templated on Context only; the element type (float on CPU, float or
float16 on CUDA) is selected at run time via DoRunWithType/DispatchHelper,
and the GPU registrations move to a .cu file that adds a repeatCopy kernel.

Was https://github.com/caffe2/caffe2/pull/1151
Closes https://github.com/caffe2/caffe2/pull/1192

Reviewed By: salexspb

Differential Revision: D5829775

Pulled By: akyrola

fbshipit-source-id: e0f7609317ca95faf9eb9c81b265d678a24a80e3
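
The core of the change is visible in recurrent_network_op.h below: the operator
classes drop the element-type template parameter and keep only the Context, with
the type-specific body moved into DoRunWithType<T>(). A minimal standalone sketch
of that pattern (illustrative names only, not the actual Caffe2 classes or
DispatchHelper):

    // Illustrative sketch: the operator is templated on the device context
    // alone, and the element type is resolved per call.
    #include <cstdint>
    #include <cstdio>

    enum class ElemType { kFloat, kHalf };

    template <class Context>
    class RnnOpSketch {
     public:
      // Type-specific body, analogous to DoRunWithType<T>() in the patch.
      template <typename T>
      bool DoRunWithType() {
        std::printf("step net running with %zu-byte elements\n", sizeof(T));
        return true;
      }

      // The CPU build keeps float only; the CUDA build dispatches on the
      // input tensor's element type to float or float16.
      bool RunOnDevice(ElemType t) {
        switch (t) {
          case ElemType::kFloat:
            return DoRunWithType<float>();
          case ElemType::kHalf:
            return DoRunWithType<uint16_t>();  // stand-in for float16 storage
        }
        return false;
      }
    };

    struct FakeCpuContext {};

    int main() {
      RnnOpSketch<FakeCpuContext> op;
      op.RunOnDevice(ElemType::kFloat);
      op.RunOnDevice(ElemType::kHalf);
      return 0;
    }

On CPU the new RunOnDevice() stays pinned to float, while the CUDA build (see
recurrent_network_op_gpu.cu below) dispatches over {float, float16} per input.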
diff --git a/caffe2/operators/recurrent_network_op.cc b/caffe2/operators/recurrent_network_op.cc
index 7bf8f4f..97caa27 100644
--- a/caffe2/operators/recurrent_network_op.cc
+++ b/caffe2/operators/recurrent_network_op.cc
@@ -9,7 +9,7 @@
 namespace caffe2 {
 CAFFE_KNOWN_TYPE(detail::ScratchWorkspaces);
 
-REGISTER_CPU_OPERATOR(RecurrentNetwork, RecurrentNetworkOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(RecurrentNetwork, RecurrentNetworkOp<CPUContext>);
 OPERATOR_SCHEMA(RecurrentNetwork)
     .NumInputs(1, INT_MAX)
     .NumOutputs(2, INT_MAX)
@@ -36,12 +36,12 @@
 
 REGISTER_CPU_OPERATOR(
     RecurrentNetworkGradient,
-    RecurrentNetworkGradientOp<float, CPUContext>);
+    RecurrentNetworkGradientOp<CPUContext>);
 OPERATOR_SCHEMA(RecurrentNetworkGradient);
 
 REGISTER_CPU_OPERATOR(
     rnn_internal_accumulate_gradient_input,
-    AccumulateInputGradientOp<float, CPUContext>);
+    AccumulateInputGradientOp<CPUContext>);
 OPERATOR_SCHEMA(rnn_internal_accumulate_gradient_input)
     .NumInputs(3)
     .NumOutputs(1, INT_MAX)
@@ -51,7 +51,7 @@
 
 REGISTER_CPU_OPERATOR(
     rnn_internal_apply_link,
-    RNNApplyLinkOp<float, CPUContext>);
+    RNNApplyLinkOp<CPUContext>);
 OPERATOR_SCHEMA(rnn_internal_apply_link)
     .NumInputs(2)
     .NumOutputs(2)
diff --git a/caffe2/operators/recurrent_network_op.h b/caffe2/operators/recurrent_network_op.h
index 95e940d..8002d69 100644
--- a/caffe2/operators/recurrent_network_op.h
+++ b/caffe2/operators/recurrent_network_op.h
@@ -7,6 +7,7 @@
 #include "caffe2/core/tensor.h"
 #include "caffe2/operators/recurrent_network_executor.h"
 #include "google/protobuf/text_format.h"
+#include "caffe2/utils/conversions.h"
 
 CAFFE2_DECLARE_bool(caffe2_rnn_executor);
 
@@ -87,6 +88,18 @@
       dst->size());
 }
 
+template <typename T, class Context>
+void repeatCopy(
+    size_t repeat_n,
+    size_t n,
+    const T* src,
+    T* dst,
+    Context* context) {
+  for (int i = 0; i < repeat_n; ++i) {
+    context->template Copy<T, Context, Context>(n, src, dst + i * n);
+  }
+}
+
 /**
  * Copy external input to the step net into the first item of
  * (T + 1) X batch_size X input_size tensor
@@ -127,14 +140,14 @@
         input.template data<T>(),
         state->template mutable_data<T>());
   } else {
-    for (int i = 0; i < batchSize; ++i) {
-      // Usually, the initial state is the same for all inputs in the batch.
-      // So the op conveniently accepts 1-D input and copies it batchSize times.
-      context->template Copy<T, Context, Context>(
+    // Usually, the initial state is the same for all inputs in the batch.
+    // So the op conveniently accepts 1-D input and copies it batchSize times.
+    repeatCopy<T, Context>(
+          batchSize,
           stateSize,
           input.template data<T>(),
-          state->template mutable_data<T>() + i * stateSize);
-    }
+          state->template mutable_data<T>(),
+          context);
   }
 }
 
@@ -155,7 +168,7 @@
     std::vector<detail::Link>* links);
 } // namespace detail
 
-template <typename T, class Context>
+template <class Context>
 class RecurrentNetworkOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -273,7 +286,8 @@
     return links;
   }
 
-  bool RunOnDevice() {
+  template<typename T>
+  bool DoRunWithType() {
     const auto seqLen = Input(0).dim32(0);
     const auto batchSize = Input(0).dim32(1);
     for (const auto& ri : recurrentInputs_) {
@@ -356,6 +370,10 @@
     return true;
   }
 
+  bool RunOnDevice() {
+    return DoRunWithType<float>();
+  }
+
  protected:
   NetDef stepNetDef_;
   Workspace* sharedWs_;
@@ -368,7 +386,7 @@
   std::string timestep_;
 };
 
-template <typename T, class Context>
+template <class Context>
 class RecurrentNetworkGradientOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -618,7 +636,8 @@
     }
   }
 
-  bool RunOnDevice() {
+  template<typename T>
+  bool DoRunWithType() {
     const auto seqLen = Input(gradInputs_.size()).dim32(0);
     VLOG(1) << "seqLen: " << seqLen;
 
@@ -640,7 +659,10 @@
       auto* g = gBlob->template GetMutable<Tensor<Context>>();
       g->ResizeLike(p);
       math::Set<T, Context>(
-          g->size(), 0.0, g->template mutable_data<T>(), &context_);
+          g->size(),
+          convert::To<float,T>(0.0),
+          g->template mutable_data<T>(),
+          &context_);
     }
 
     for (auto& rg : recurrentGradients_) {
@@ -657,7 +679,7 @@
       // Fill the last timestep with zeros for the gradient
       math::Set<T, Context>(
           timestep,
-          0.0,
+          convert::To<float,T>(0.0),
           g->template mutable_data<T>() + (g->dim(0) - 1) * timestep,
           &context_);
     }
@@ -766,7 +788,11 @@
         // which sums up several tensors together instead of going 1 by 1
         const auto recurrentStateSize = Input(inputId).dim32(0);
 
-        math::Set<T, Context>(recurrentStateSize, 0.0, output_data, &context_);
+        math::Set<T, Context>(
+            recurrentStateSize,
+            convert::To<float,T>(0.0),
+            output_data,
+            &context_);
 
         math::AddStripedBatch<T, Context>(
             recurrentStateSize,
@@ -781,6 +807,10 @@
     return true;
   }
 
+  bool RunOnDevice() {
+    return DoRunWithType<float>();
+  }
+
  protected:
   NetDef stepNetDef_;
   Workspace* sharedWs_;
@@ -796,7 +826,7 @@
   std::vector<int32_t> gradInputs_;
 };
 
-template <typename T, class Context>
+template <class Context>
 class AccumulateInputGradientOp : public Operator<Context> {
  public:
   AccumulateInputGradientOp(const OperatorDef& def, Workspace* ws)
@@ -806,7 +836,8 @@
   }
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
-  bool RunOnDevice() override {
+  template<typename T>
+  bool DoRunWithType() {
     const auto t =
         OperatorBase::Input<Tensor<CPUContext>>(0).template data<int32_t>()[0];
     auto& og = Input(1);
@@ -831,11 +862,15 @@
     return true;
   }
 
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(1));
+  }
+
  private:
   int offset_;
 };
 
-template <typename T, class Context>
+template <class Context>
 class RNNApplyLinkOp : public Operator<Context> {
  public:
   RNNApplyLinkOp(const OperatorDef& def, Workspace* ws)
@@ -848,7 +883,8 @@
 
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
-  bool RunOnDevice() override {
+  template <typename T>
+  bool DoRunWithType() {
     // Both internal and external appear as both input and output to enforce
     // correct dependency computation.
     const auto t =
@@ -871,6 +907,10 @@
     return true;
   }
 
+  bool RunOnDevice() override {
+    return DoRunWithType<float>();
+  }
+
  private:
   int offset_;
   int window_;
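
The new detail::repeatCopy helper above broadcasts a 1-D initial state across the
batch; the CPU version simply loops over Context::Copy, and the CUDA version below
replaces that loop with a kernel. A minimal standalone sketch of the same
semantics, with std::copy standing in for the Caffe2 Copy call (illustrative
only, not the real API):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Same contract as detail::repeatCopy in the patch: write repeat_n
    // back-to-back copies of the n-element buffer src into dst.
    template <typename T>
    void repeatCopySketch(size_t repeat_n, size_t n, const T* src, T* dst) {
      for (size_t i = 0; i < repeat_n; ++i) {
        std::copy(src, src + n, dst + i * n);
      }
    }

    int main() {
      const size_t batchSize = 3, stateSize = 4;
      std::vector<float> initialState = {0.1f, 0.2f, 0.3f, 0.4f};  // 1-D input
      std::vector<float> state(batchSize * stateSize);             // batch x state

      // Every batch element starts from the same initial hidden state.
      repeatCopySketch(batchSize, stateSize, initialState.data(), state.data());

      for (size_t b = 0; b < batchSize; ++b) {
        for (size_t j = 0; j < stateSize; ++j) {
          std::printf("%.1f ", state[b * stateSize + j]);
        }
        std::printf("\n");
      }
      return 0;
    }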
diff --git a/caffe2/operators/recurrent_network_op_gpu.cc b/caffe2/operators/recurrent_network_op_gpu.cc
deleted file mode 100644
index 80e2a6d..0000000
--- a/caffe2/operators/recurrent_network_op_gpu.cc
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "caffe2/core/context_gpu.h"
-#include "caffe2/operators/recurrent_network_op.h"
-
-namespace caffe2 {
-REGISTER_CUDA_OPERATOR(
-    RecurrentNetwork,
-    RecurrentNetworkOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(
-    RecurrentNetworkGradient,
-    RecurrentNetworkGradientOp<float, CUDAContext>);
-
-REGISTER_CUDA_OPERATOR(
-    rnn_internal_accumulate_gradient_input,
-    AccumulateInputGradientOp<float, CUDAContext>);
-
-REGISTER_CUDA_OPERATOR(
-    rnn_internal_apply_link,
-    RNNApplyLinkOp<float, CUDAContext>);
-}
diff --git a/caffe2/operators/recurrent_network_op_gpu.cu b/caffe2/operators/recurrent_network_op_gpu.cu
new file mode 100644
index 0000000..d7c70dd
--- /dev/null
+++ b/caffe2/operators/recurrent_network_op_gpu.cu
@@ -0,0 +1,94 @@
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/recurrent_network_op.h"
+
+namespace caffe2 {
+
+namespace detail {
+
+template <typename T, typename Context>
+void initializeRecurrentInput(
+    const RecurrentInput& rc,
+    int32_t seqLen,
+    int32_t batchSize,
+    Workspace* ws,
+    Context* context);
+
+namespace {
+
+template <typename T>
+__global__
+void initRecurrentInput_kernel(
+    size_t stateSize,
+    const T* input,
+    T* state) {
+  // index into appropriate target buffer
+  const int block_id = blockIdx.x;
+  T* state_local = state + block_id*stateSize;
+
+  // copy
+  for (int idx=threadIdx.x; idx < stateSize; idx+=blockDim.x) {
+    state_local[idx] = input[idx];
+  }
+}
+
+
+}; // namespace
+
+template <>
+void repeatCopy(
+    size_t repeat_n,
+    size_t n,
+    const float* src,
+    float* dst,
+    CUDAContext* context) {
+    initRecurrentInput_kernel<float><<<repeat_n, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+        n, src, dst);
+}
+template <>
+void repeatCopy(
+    size_t repeat_n,
+    size_t n,
+    const float16* src,
+    float16* dst,
+    CUDAContext* context) {
+    initRecurrentInput_kernel<float16><<<repeat_n, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+        n, src, dst);
+}
+
+}; // namespace detail
+
+template <>
+bool RecurrentNetworkOp<CUDAContext>::RunOnDevice() {
+  return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
+}
+
+template <>
+bool RecurrentNetworkGradientOp<CUDAContext>::RunOnDevice() {
+  return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
+}
+
+template <>
+bool AccumulateInputGradientOp<CUDAContext>::RunOnDevice() {
+  return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(1));
+}
+
+template <>
+bool RNNApplyLinkOp<CUDAContext>::RunOnDevice() {
+  return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(1));
+}
+
+REGISTER_CUDA_OPERATOR(
+    RecurrentNetwork,
+    RecurrentNetworkOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(
+    RecurrentNetworkGradient,
+    RecurrentNetworkGradientOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(
+    rnn_internal_accumulate_gradient_input,
+    AccumulateInputGradientOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(
+    rnn_internal_apply_link,
+    RNNApplyLinkOp<CUDAContext>);
+
+
+} // namespace caffe2
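
For reference, the CUDA repeatCopy specializations above launch one block per
repetition and let the block's threads stride over the n elements of the slice.
A standalone sketch of that launch pattern in plain float (hypothetical names,
error checking omitted; the real code uses CAFFE_CUDA_NUM_THREADS and the
context's stream):

    #include <cstdio>
    #include <cuda_runtime.h>

    // One block per repetition; the threads of the block stride over the n
    // elements of that slice, as in initRecurrentInput_kernel above.
    __global__ void repeatCopyKernel(size_t n, const float* src, float* dst) {
      float* out = dst + blockIdx.x * n;
      for (size_t idx = threadIdx.x; idx < n; idx += blockDim.x) {
        out[idx] = src[idx];
      }
    }

    int main() {
      const size_t repeat_n = 3, n = 4;
      const float host_src[n] = {0.1f, 0.2f, 0.3f, 0.4f};

      float *src = nullptr, *dst = nullptr;
      cudaMalloc(&src, n * sizeof(float));
      cudaMalloc(&dst, repeat_n * n * sizeof(float));
      cudaMemcpy(src, host_src, n * sizeof(float), cudaMemcpyHostToDevice);

      // repeat_n blocks of 128 threads each (stand-in for CAFFE_CUDA_NUM_THREADS).
      repeatCopyKernel<<<repeat_n, 128>>>(n, src, dst);

      float host_dst[repeat_n * n];
      cudaMemcpy(host_dst, dst, sizeof(host_dst), cudaMemcpyDeviceToHost);
      for (size_t i = 0; i < repeat_n * n; ++i) {
        std::printf("%.1f%s", host_dst[i], (i + 1) % n == 0 ? "\n" : " ");
      }

      cudaFree(src);
      cudaFree(dst);
      return 0;
    }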