Registering GPU version of PackSegments using GPUFallbackOp

Summary: Create GPU versions of the PackSegments and UnpackSegments operators using GPUFallbackOp for now. These ops mainly copy blob data, so falling back to the CPU implementation is a reasonable solution until we have a native CUDA op.

Reviewed By: pietern

Differential Revision: D4761589

fbshipit-source-id: dd483b9e34ecb6b53925405e5b4c24859c549606
diff --git a/caffe2/operators/pack_segments.cc b/caffe2/operators/pack_segments.cc
index ce7d6ff..81ccd63 100644
--- a/caffe2/operators/pack_segments.cc
+++ b/caffe2/operators/pack_segments.cc
@@ -1,155 +1,9 @@
-#include <atomic>
-#include <limits>
-#include <mutex>
-#include <unordered_map>
-#include <vector>
-#include "caffe2/core/operator.h"
-#include "caffe2/core/tensor.h"
-#include "caffe2/utils/math.h"
+#include "caffe2/operators/pack_segments.h"
 
 namespace caffe2 {
 
 namespace {
 
-template <class Context>
-class PackSegmentsOp final : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  // USE_SIMPLE_CTOR_DTOR(PackSegmentsOp)
-  USE_DISPATCH_HELPER;
-
-  PackSegmentsOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
-      pad_minf_(
-        OperatorBase::GetSingleArgument<bool>("pad_minf", false)) {
-          if (pad_minf_) {
-            padding_ = -1.0 * std::numeric_limits<float>::infinity();
-          } else {
-            padding_ = 0;
-          }
-        }
-
-
-  bool RunOnDevice() override {
-    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
-  }
-
-  template <typename T>
-  bool DoRunWithType() {
-    const auto& data = Input(DATA);
-    const auto& lengths = Input(LENGTHS);
-    auto* output = Output(0);
-
-    CAFFE_ENFORCE(data.ndim() >= 1, "DATA should be at least 1-D");
-    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
-
-    // Find the length of the longest sequence.
-    const T* l = lengths.template data<T>();
-    T max_length = 0;
-    for (T i = 0; i < lengths.dim(0); ++i) {
-      max_length = std::max(max_length, l[i]);
-    }
-
-    auto shape = data.dims(); // Shape of output is batch_size x max_len x ...
-    shape[0] = max_length;
-    shape.insert(shape.begin(), lengths.size());
-    output->Resize(shape);
-    // create output tensor
-    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
-
-    if (!data.dim(0)) {
-      // Return empty output (with the proper shape)
-      return true;
-    }
-
-    // Do padding
-    if (output->template IsType<float>()) {
-      math::Set<float, Context>(
-          output->size(),
-          padding_,
-          output->template mutable_data<float>(),
-          &context_);
-    }
-
-    int block_size = data.size() / data.dim(0);
-    int block_bytesize = data.nbytes() / data.dim(0);
-    const auto* d = static_cast<const char*>(data.raw_data());
-    int start = 0;
-    for (int i = 0; i < lengths.dim(0); ++i) {
-      context_.template CopyItems<Context, Context>(
-          data.meta(),
-          l[i] * block_size,
-          d + block_bytesize * start,
-          out + block_bytesize * max_length * i);
-      start += l[i];
-    }
-
-    return true;
-  }
-
-  INPUT_TAGS(LENGTHS, DATA);
-  private:
-    bool pad_minf_;
-    float padding_;
-};
-
-template <class Context>
-class UnpackSegmentsOp final : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  USE_SIMPLE_CTOR_DTOR(UnpackSegmentsOp)
-  USE_DISPATCH_HELPER;
-
-  bool RunOnDevice() override {
-    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
-  }
-
-  template <typename T>
-  bool DoRunWithType() {
-    const auto& data = Input(DATA);
-    const auto& lengths = Input(LENGTHS);
-    auto* output = Output(0);
-
-    CAFFE_ENFORCE(data.ndim() >= 2, "DATA should be at least 2-D");
-    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
-
-    const T* l = lengths.template data<T>();
-
-    T max_length = 0;
-    for (T i = 0; i < lengths.dim(0); ++i) {
-      max_length = std::max(max_length, l[i]);
-    }
-    T total_l = std::accumulate(l, l + lengths.dim(0), 0);
-
-    auto shape = data.dims();
-    CAFFE_ENFORCE(
-        shape[0] == lengths.dim(0), "LENGTH should match DATA in dimension 0");
-    shape.erase(shape.begin());
-    shape[0] = total_l;
-    output->Resize(shape);
-    // create output tensor
-    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
-    if (!(data.dim(0) * data.dim(1))) {
-      return true;
-    }
-    int block_size = data.size() / (data.dim(0) * data.dim(1));
-    int block_bytesize = data.nbytes() / (data.dim(0) * data.dim(1));
-    const auto* d = static_cast<const char*>(data.raw_data());
-    int start = 0;
-    for (int i = 0; i < lengths.dim(0); ++i) {
-      context_.template CopyItems<Context, Context>(
-          data.meta(),
-          l[i] * block_size,
-          d + block_bytesize * data.dim(1) * i,
-          out + block_bytesize * start);
-      start += l[i];
-    }
-    return true;
-  }
-
-  INPUT_TAGS(LENGTHS, DATA);
-};
-
 REGISTER_CPU_OPERATOR(PackSegments, PackSegmentsOp<CPUContext>);
 REGISTER_CPU_OPERATOR(UnpackSegments, UnpackSegmentsOp<CPUContext>);
 
diff --git a/caffe2/operators/pack_segments.h b/caffe2/operators/pack_segments.h
new file mode 100644
index 0000000..720ba99
--- /dev/null
+++ b/caffe2/operators/pack_segments.h
@@ -0,0 +1,154 @@
+#ifndef CAFFE2_OPERATORS_PACK_SEGMENTS_H_
+#define CAFFE2_OPERATORS_PACK_SEGMENTS_H_
+
+#include <atomic>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <vector>
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+template <class Context>
+class PackSegmentsOp final : public Operator<Context> { // Packs DATA rows into a [num_segments, max_length, ...] output.
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  // No USE_SIMPLE_CTOR_DTOR(PackSegmentsOp): the constructor below also reads the "pad_minf" argument.
+  USE_DISPATCH_HELPER;
+
+  PackSegmentsOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws),
+        pad_minf_(OperatorBase::GetSingleArgument<bool>("pad_minf", false)) {
+    if (pad_minf_) { // pad with -infinity instead of 0
+      padding_ = -1.0 * std::numeric_limits<float>::infinity();
+    } else {
+      padding_ = 0;
+    }
+  }
+
+  bool RunOnDevice() override { // NOTE(review): presumably runs DoRunWithType<T> with T = LENGTHS element type (int or long) - confirm in DispatchHelper
+    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
+  }
+
+  template <typename T> // T is the element type used to read LENGTHS
+  bool DoRunWithType() {
+    const auto& data = Input(DATA);
+    const auto& lengths = Input(LENGTHS);
+    auto* output = Output(0);
+
+    CAFFE_ENFORCE(data.ndim() >= 1, "DATA should be at least 1-D");
+    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
+
+    // Longest segment length; it becomes the second dimension of the output.
+    const T* l = lengths.template data<T>();
+    T max_length = 0;
+    for (T i = 0; i < lengths.dim(0); ++i) {
+      max_length = std::max(max_length, l[i]);
+    }
+
+    auto shape = data.dims(); // Shape of output is batch_size x max_len x ...
+    shape[0] = max_length; // DATA's row dimension is replaced by max_length
+    shape.insert(shape.begin(), lengths.size()); // prepend the number of segments
+    output->Resize(shape);
+    // Get the output buffer, typed like DATA, and address it as raw bytes.
+    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
+
+    if (!data.dim(0)) {
+      // DATA has no rows: return the output as resized above, nothing to copy.
+      return true;
+    }
+
+    // Fill the whole output with the padding value; only float outputs are filled here.
+    if (output->template IsType<float>()) {
+      math::Set<float, Context>(
+          output->size(),
+          padding_,
+          output->template mutable_data<float>(),
+          &context_);
+    }
+
+    int block_size = data.size() / data.dim(0); // elements per DATA row
+    int block_bytesize = data.nbytes() / data.dim(0); // bytes per DATA row
+    const auto* d = static_cast<const char*>(data.raw_data());
+    int start = 0; // index of the first DATA row of the current segment
+    for (int i = 0; i < lengths.dim(0); ++i) {
+      context_.template CopyItems<Context, Context>(
+          data.meta(),
+          l[i] * block_size, // number of items in segment i
+          d + block_bytesize * start, // source: first row of segment i in DATA
+          out + block_bytesize * max_length * i); // destination: output row i * max_length
+      start += l[i];
+    }
+
+    return true;
+  }
+
+  INPUT_TAGS(LENGTHS, DATA); // tags listed in order: LENGTHS, then DATA
+
+ private:
+  bool pad_minf_; // value of the "pad_minf" argument (false when absent)
+  float padding_; // fill value: -infinity when pad_minf_ is set, otherwise 0
+};
+
+template <class Context>
+class UnpackSegmentsOp final : public Operator<Context> { // Concatenates the first LENGTHS[i] rows of each DATA segment into one output.
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  USE_SIMPLE_CTOR_DTOR(UnpackSegmentsOp)
+  USE_DISPATCH_HELPER;
+
+  bool RunOnDevice() override { // NOTE(review): presumably runs DoRunWithType<T> with T = LENGTHS element type (int or long) - confirm in DispatchHelper
+    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
+  }
+
+  template <typename T> // T is the element type used to read LENGTHS
+  bool DoRunWithType() {
+    const auto& data = Input(DATA);
+    const auto& lengths = Input(LENGTHS);
+    auto* output = Output(0);
+
+    CAFFE_ENFORCE(data.ndim() >= 2, "DATA should be at least 2-D");
+    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
+
+    const T* l = lengths.template data<T>();
+
+    T max_length = 0; // NOTE(review): computed here but never read afterwards
+    for (T i = 0; i < lengths.dim(0); ++i) {
+      max_length = std::max(max_length, l[i]);
+    }
+    T total_l = std::accumulate(l, l + lengths.dim(0), 0); // sum of all lengths; NOTE(review): the int init value makes the sum accumulate as int
+
+    auto shape = data.dims();
+    CAFFE_ENFORCE(
+        shape[0] == lengths.dim(0), "LENGTH should match DATA in dimension 0");
+    shape.erase(shape.begin()); // drop the leading segment dimension
+    shape[0] = total_l; // the former second dimension becomes the total row count
+    output->Resize(shape);
+    // Get the output buffer, typed like DATA, and address it as raw bytes.
+    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
+    if (!(data.dim(0) * data.dim(1))) { // DATA has no rows: nothing to copy
+      return true;
+    }
+    int block_size = data.size() / (data.dim(0) * data.dim(1)); // elements per row
+    int block_bytesize = data.nbytes() / (data.dim(0) * data.dim(1)); // bytes per row
+    const auto* d = static_cast<const char*>(data.raw_data());
+    int start = 0; // index of the next output row to write
+    for (int i = 0; i < lengths.dim(0); ++i) {
+      context_.template CopyItems<Context, Context>(
+          data.meta(),
+          l[i] * block_size, // number of items kept from segment i
+          d + block_bytesize * data.dim(1) * i, // source: first row of segment i
+          out + block_bytesize * start); // destination: directly after the previous segment
+      start += l[i];
+    }
+    return true;
+  }
+
+  INPUT_TAGS(LENGTHS, DATA); // tags listed in order: LENGTHS, then DATA
+};
+
+} // namespace caffe2
+#endif // CAFFE2_OPERATORS_PACK_SEGMENTS_H_
diff --git a/caffe2/operators/pack_segments_op_gpu.cc b/caffe2/operators/pack_segments_op_gpu.cc
new file mode 100644
index 0000000..86a19a9
--- /dev/null
+++ b/caffe2/operators/pack_segments_op_gpu.cc
@@ -0,0 +1,12 @@
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/operator_fallback_gpu.h"
+#include "caffe2/operators/pack_segments.h"
+
+namespace caffe2 {
+namespace { // CUDA registrations that wrap the CPUContext operators in GPUFallbackOp
+REGISTER_CUDA_OPERATOR(PackSegments, GPUFallbackOp<PackSegmentsOp<CPUContext>>);
+REGISTER_CUDA_OPERATOR(
+    UnpackSegments,
+    GPUFallbackOp<UnpackSegmentsOp<CPUContext>>);
+} // namespace
+} // namespace caffe2
diff --git a/caffe2/python/operator_test/pack_ops_test.py b/caffe2/python/operator_test/pack_ops_test.py
index 0b11d15..1712391 100644
--- a/caffe2/python/operator_test/pack_ops_test.py
+++ b/caffe2/python/operator_test/pack_ops_test.py
@@ -2,25 +2,58 @@
 from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
-import numpy as np
 
 from caffe2.python import core, workspace
-from caffe2.python.test_util import TestCase
+import caffe2.python.hypothesis_test_util as hu
+
+from hypothesis import given
+import numpy as np
 
 
-class TestTensorPackOps(TestCase):
-    def test_pack_ops(self):
-        workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int32))
-        workspace.FeedBlob(
-            'd',
-            np.array([
-                [1.0, 1.0],
-                [2.0, 2.0],
-                [2.0, 2.0],
-                [3.0, 3.0],
-                [3.0, 3.0],
-                [3.0, 3.0]],
-                dtype=np.float32))
+class TestTensorPackOps(hu.HypothesisTestCase):
+    @given(**hu.gcs)
+    def test_pack_ops(self, gc, dc):
+        lengths = np.array([1, 2, 3], dtype=np.int32)
+        data = np.array([
+            [1.0, 1.0],
+            [2.0, 2.0],
+            [2.0, 2.0],
+            [3.0, 3.0],
+            [3.0, 3.0],
+            [3.0, 3.0]], dtype=np.float32)
+        op = core.CreateOperator(
+            'PackSegments', ['l', 'd'], ['t'])
+        print(gc, dc)
+
+        def pack_segments_ref(lengths, data):
+            arr = []
+            constant_values = 0
+            if data.dtype.char == 'S':
+                constant_values = ''
+            for idx in range(np.size(lengths)):
+                chunk = data[np.sum(lengths[:idx]):np.sum(lengths[:idx + 1])]
+                pad_length = np.max(lengths) - lengths[idx]
+
+                # ((0, pad_length), (0, 0)) says add pad_length rows of padding
+                # below chunk and 0 rows of padding elsewhere
+                arr.append(np.pad(
+                    chunk,
+                    ((0, pad_length), (0, 0)),
+                    mode=str("constant"),
+                    constant_values=constant_values))
+            return [arr]
+        workspace.FeedBlob('l', lengths)
+        workspace.FeedBlob('d', data)
+        inputs = [lengths, data]
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=inputs,
+            reference=pack_segments_ref,
+        )
+        workspace.FeedBlob('l', lengths)
+        workspace.FeedBlob('d', data)
+
         workspace.RunOperatorOnce(core.CreateOperator(
             'PackSegments', ['l', 'd'], ['t']))
         workspace.RunOperatorOnce(core.CreateOperator(
@@ -67,6 +100,7 @@
         exponentiated = workspace.FetchBlob('r')
         assert(exponentiated[0, -1, 0] == 0.0)
 
+
 if __name__ == "__main__":
     import unittest
     unittest.main()