Registering GPU versions of PackSegments and UnpackSegments using GPUFallbackOp
Summary: Create PackSegments and UnpackSegments GPU operators via GPUFallbackOp for now. These ops mostly copy blobs, so falling back to the CPU implementation is a reasonable solution until we have native CUDA kernels.
Reviewed By: pietern
Differential Revision: D4761589
fbshipit-source-id: dd483b9e34ecb6b53925405e5b4c24859c549606
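For context, a sketch of the semantics (illustrative only, not part of this diff; the helper names below are made up): PackSegments turns a flat DATA tensor plus a LENGTHS vector into a padded batch of shape (num_segments, max_length, ...), and UnpackSegments inverts it.

    import numpy as np

    def pack_segments(lengths, data, padding=0.0):
        # Output shape: (num_segments, max_length, ...), filled with `padding`.
        max_len = lengths.max()
        out = np.full((len(lengths), max_len) + data.shape[1:], padding,
                      dtype=data.dtype)
        start = 0
        for i, n in enumerate(lengths):
            out[i, :n] = data[start:start + n]  # copy one segment
            start += n
        return out

    def unpack_segments(lengths, packed):
        # Keep only the first lengths[i] rows of each segment.
        return np.concatenate([packed[i, :n] for i, n in enumerate(lengths)])

    lengths = np.array([1, 2, 3])
    data = np.arange(12, dtype=np.float32).reshape(6, 2)
    packed = pack_segments(lengths, data)  # shape (3, 3, 2)
    assert np.array_equal(unpack_segments(lengths, packed), data)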
diff --git a/caffe2/operators/pack_segments.cc b/caffe2/operators/pack_segments.cc
index ce7d6ff..81ccd63 100644
--- a/caffe2/operators/pack_segments.cc
+++ b/caffe2/operators/pack_segments.cc
@@ -1,155 +1,9 @@
-#include <atomic>
-#include <limits>
-#include <mutex>
-#include <unordered_map>
-#include <vector>
-#include "caffe2/core/operator.h"
-#include "caffe2/core/tensor.h"
-#include "caffe2/utils/math.h"
+#include "caffe2/operators/pack_segments.h"
namespace caffe2 {
namespace {
-template <class Context>
-class PackSegmentsOp final : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  // USE_SIMPLE_CTOR_DTOR(PackSegmentsOp)
-  USE_DISPATCH_HELPER;
-
-  PackSegmentsOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
-        pad_minf_(
-            OperatorBase::GetSingleArgument<bool>("pad_minf", false)) {
-    if (pad_minf_) {
-      padding_ = -1.0 * std::numeric_limits<float>::infinity();
-    } else {
-      padding_ = 0;
-    }
-  }
-
-
-  bool RunOnDevice() override {
-    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
-  }
-
-  template <typename T>
-  bool DoRunWithType() {
-    const auto& data = Input(DATA);
-    const auto& lengths = Input(LENGTHS);
-    auto* output = Output(0);
-
-    CAFFE_ENFORCE(data.ndim() >= 1, "DATA should be at least 1-D");
-    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
-
-    // Find the length of the longest sequence.
-    const T* l = lengths.template data<T>();
-    T max_length = 0;
-    for (T i = 0; i < lengths.dim(0); ++i) {
-      max_length = std::max(max_length, l[i]);
-    }
-
-    auto shape = data.dims(); // Shape of output is batch_size x max_len x ...
-    shape[0] = max_length;
-    shape.insert(shape.begin(), lengths.size());
-    output->Resize(shape);
-    // create output tensor
-    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
-
-    if (!data.dim(0)) {
-      // Return empty output (with the proper shape)
-      return true;
-    }
-
-    // Do padding
-    if (output->template IsType<float>()) {
-      math::Set<float, Context>(
-          output->size(),
-          padding_,
-          output->template mutable_data<float>(),
-          &context_);
-    }
-
-    int block_size = data.size() / data.dim(0);
-    int block_bytesize = data.nbytes() / data.dim(0);
-    const auto* d = static_cast<const char*>(data.raw_data());
-    int start = 0;
-    for (int i = 0; i < lengths.dim(0); ++i) {
-      context_.template CopyItems<Context, Context>(
-          data.meta(),
-          l[i] * block_size,
-          d + block_bytesize * start,
-          out + block_bytesize * max_length * i);
-      start += l[i];
-    }
-
-    return true;
-  }
-
-  INPUT_TAGS(LENGTHS, DATA);
- private:
-  bool pad_minf_;
-  float padding_;
-};
-
-template <class Context>
-class UnpackSegmentsOp final : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  USE_SIMPLE_CTOR_DTOR(UnpackSegmentsOp)
-  USE_DISPATCH_HELPER;
-
-  bool RunOnDevice() override {
-    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
-  }
-
-  template <typename T>
-  bool DoRunWithType() {
-    const auto& data = Input(DATA);
-    const auto& lengths = Input(LENGTHS);
-    auto* output = Output(0);
-
-    CAFFE_ENFORCE(data.ndim() >= 2, "DATA should be at least 2-D");
-    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
-
-    const T* l = lengths.template data<T>();
-
-    T max_length = 0;
-    for (T i = 0; i < lengths.dim(0); ++i) {
-      max_length = std::max(max_length, l[i]);
-    }
-    T total_l = std::accumulate(l, l + lengths.dim(0), 0);
-
-    auto shape = data.dims();
-    CAFFE_ENFORCE(
-        shape[0] == lengths.dim(0), "LENGTH should match DATA in dimension 0");
-    shape.erase(shape.begin());
-    shape[0] = total_l;
-    output->Resize(shape);
-    // create output tensor
-    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
-    if (!(data.dim(0) * data.dim(1))) {
-      return true;
-    }
-    int block_size = data.size() / (data.dim(0) * data.dim(1));
-    int block_bytesize = data.nbytes() / (data.dim(0) * data.dim(1));
-    const auto* d = static_cast<const char*>(data.raw_data());
-    int start = 0;
-    for (int i = 0; i < lengths.dim(0); ++i) {
-      context_.template CopyItems<Context, Context>(
-          data.meta(),
-          l[i] * block_size,
-          d + block_bytesize * data.dim(1) * i,
-          out + block_bytesize * start);
-      start += l[i];
-    }
-    return true;
-  }
-
-  INPUT_TAGS(LENGTHS, DATA);
-};
-
REGISTER_CPU_OPERATOR(PackSegments, PackSegmentsOp<CPUContext>);
REGISTER_CPU_OPERATOR(UnpackSegments, UnpackSegmentsOp<CPUContext>);
diff --git a/caffe2/operators/pack_segments.h b/caffe2/operators/pack_segments.h
new file mode 100644
index 0000000..720ba99
--- /dev/null
+++ b/caffe2/operators/pack_segments.h
@@ -0,0 +1,154 @@
+#ifndef CAFFE2_OPERATORS_PACK_SEGMENTS_H_
+#define CAFFE2_OPERATORS_PACK_SEGMENTS_H_
+
+#include <atomic>
+#include <limits>
+#include <mutex>
+#include <numeric>
+#include <unordered_map>
+#include <vector>
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+template <class Context>
+class PackSegmentsOp final : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  // USE_SIMPLE_CTOR_DTOR(PackSegmentsOp)
+  USE_DISPATCH_HELPER;
+
+  PackSegmentsOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws),
+        pad_minf_(OperatorBase::GetSingleArgument<bool>("pad_minf", false)) {
+    if (pad_minf_) {
+      padding_ = -1.0 * std::numeric_limits<float>::infinity();
+    } else {
+      padding_ = 0;
+    }
+  }
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& data = Input(DATA);
+    const auto& lengths = Input(LENGTHS);
+    auto* output = Output(0);
+
+    CAFFE_ENFORCE(data.ndim() >= 1, "DATA should be at least 1-D");
+    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
+
+    // Find the length of the longest sequence.
+    const T* l = lengths.template data<T>();
+    T max_length = 0;
+    for (T i = 0; i < lengths.dim(0); ++i) {
+      max_length = std::max(max_length, l[i]);
+    }
+
+    auto shape = data.dims(); // Shape of output is batch_size x max_len x ...
+    shape[0] = max_length;
+    shape.insert(shape.begin(), lengths.size());
+    output->Resize(shape);
+    // create output tensor
+    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
+
+    if (!data.dim(0)) {
+      // Return empty output (with the proper shape)
+      return true;
+    }
+
+    // Do padding
+    if (output->template IsType<float>()) {
+      math::Set<float, Context>(
+          output->size(),
+          padding_,
+          output->template mutable_data<float>(),
+          &context_);
+    }
+
+    int block_size = data.size() / data.dim(0);
+    int block_bytesize = data.nbytes() / data.dim(0);
+    const auto* d = static_cast<const char*>(data.raw_data());
+    int start = 0;
+    for (int i = 0; i < lengths.dim(0); ++i) {
+      context_.template CopyItems<Context, Context>(
+          data.meta(),
+          l[i] * block_size,
+          d + block_bytesize * start,
+          out + block_bytesize * max_length * i);
+      start += l[i];
+    }
+
+    return true;
+  }
+
+  INPUT_TAGS(LENGTHS, DATA);
+
+ private:
+  bool pad_minf_;
+  float padding_;
+};
+
+template <class Context>
+class UnpackSegmentsOp final : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  USE_SIMPLE_CTOR_DTOR(UnpackSegmentsOp)
+  USE_DISPATCH_HELPER;
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& data = Input(DATA);
+    const auto& lengths = Input(LENGTHS);
+    auto* output = Output(0);
+
+    CAFFE_ENFORCE(data.ndim() >= 2, "DATA should be at least 2-D");
+    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
+
+    const T* l = lengths.template data<T>();
+
+    T max_length = 0;
+    for (T i = 0; i < lengths.dim(0); ++i) {
+      max_length = std::max(max_length, l[i]);
+    }
+    T total_l = std::accumulate(l, l + lengths.dim(0), static_cast<T>(0));
+
+    auto shape = data.dims();
+    CAFFE_ENFORCE(
+        shape[0] == lengths.dim(0), "LENGTH should match DATA in dimension 0");
+    shape.erase(shape.begin());
+    shape[0] = total_l;
+    output->Resize(shape);
+    // create output tensor
+    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
+    if (!(data.dim(0) * data.dim(1))) {
+      return true;
+    }
+    int block_size = data.size() / (data.dim(0) * data.dim(1));
+    int block_bytesize = data.nbytes() / (data.dim(0) * data.dim(1));
+    const auto* d = static_cast<const char*>(data.raw_data());
+    int start = 0;
+    for (int i = 0; i < lengths.dim(0); ++i) {
+      context_.template CopyItems<Context, Context>(
+          data.meta(),
+          l[i] * block_size,
+          d + block_bytesize * data.dim(1) * i,
+          out + block_bytesize * start);
+      start += l[i];
+    }
+    return true;
+  }
+
+  INPUT_TAGS(LENGTHS, DATA);
+};
+
+} // namespace caffe2
+#endif // CAFFE2_OPERATORS_PACK_SEGMENTS_H_
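The pad_minf argument above selects -inf padding instead of zeros, which is useful when the packed tensor is later exponentiated: padded slots become exp(-inf) = 0, which is what the existing Python test checks. A hypothetical usage sketch (assuming blobs 'l' and 'd' are already fed, as in the test below):

    # Sketch only: pad with -inf so that Exp() maps padded positions to 0.0.
    workspace.RunOperatorOnce(core.CreateOperator(
        'PackSegments', ['l', 'd'], ['t'], pad_minf=True))
    workspace.RunOperatorOnce(core.CreateOperator('Exp', ['t'], ['r']))
    # FetchBlob('r') now holds 0.0 in every padded position.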
diff --git a/caffe2/operators/pack_segments_op_gpu.cc b/caffe2/operators/pack_segments_op_gpu.cc
new file mode 100644
index 0000000..86a19a9
--- /dev/null
+++ b/caffe2/operators/pack_segments_op_gpu.cc
@@ -0,0 +1,12 @@
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/operator_fallback_gpu.h"
+#include "caffe2/operators/pack_segments.h"
+
+namespace caffe2 {
+namespace {
+REGISTER_CUDA_OPERATOR(PackSegments, GPUFallbackOp<PackSegmentsOp<CPUContext>>);
+REGISTER_CUDA_OPERATOR(
+    UnpackSegments,
+    GPUFallbackOp<UnpackSegmentsOp<CPUContext>>);
+}
+}
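With these registrations, both ops can be scheduled on a CUDA device from Python; GPUFallbackOp copies the CUDA input blobs to CPU, runs the CPU operator, and copies the outputs back to the device. A minimal sketch (illustrative only, assuming a CUDA-enabled build):

    import numpy as np
    from caffe2.proto import caffe2_pb2
    from caffe2.python import core, workspace

    device = core.DeviceOption(caffe2_pb2.CUDA, 0)
    with core.DeviceScope(device):
        workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int32))
        workspace.FeedBlob('d', np.ones((6, 2), dtype=np.float32))
        workspace.RunOperatorOnce(core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t'], device_option=device))
    print(workspace.FetchBlob('t').shape)  # expected: (3, 3, 2)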
diff --git a/caffe2/python/operator_test/pack_ops_test.py b/caffe2/python/operator_test/pack_ops_test.py
index 0b11d15..1712391 100644
--- a/caffe2/python/operator_test/pack_ops_test.py
+++ b/caffe2/python/operator_test/pack_ops_test.py
@@ -2,25 +2,58 @@
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
-import numpy as np
from caffe2.python import core, workspace
-from caffe2.python.test_util import TestCase
+import caffe2.python.hypothesis_test_util as hu
+
+from hypothesis import given
+import numpy as np
-class TestTensorPackOps(TestCase):
-    def test_pack_ops(self):
-        workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int32))
-        workspace.FeedBlob(
-            'd',
-            np.array([
-                [1.0, 1.0],
-                [2.0, 2.0],
-                [2.0, 2.0],
-                [3.0, 3.0],
-                [3.0, 3.0],
-                [3.0, 3.0]],
-                dtype=np.float32))
+class TestTensorPackOps(hu.HypothesisTestCase):
+    @given(**hu.gcs)
+    def test_pack_ops(self, gc, dc):
+        lengths = np.array([1, 2, 3], dtype=np.int32)
+        data = np.array([
+            [1.0, 1.0],
+            [2.0, 2.0],
+            [2.0, 2.0],
+            [3.0, 3.0],
+            [3.0, 3.0],
+            [3.0, 3.0]], dtype=np.float32)
+        op = core.CreateOperator(
+            'PackSegments', ['l', 'd'], ['t'])
+
+        def pack_segments_ref(lengths, data):
+            arr = []
+            constant_values = 0
+            if data.dtype.char == 'S':
+                constant_values = ''
+            for idx in range(np.size(lengths)):
+                chunk = data[np.sum(lengths[:idx]):np.sum(lengths[:idx + 1])]
+                pad_length = np.max(lengths) - lengths[idx]
+
+                # ((0, pad_length), (0, 0)) says add pad_length rows of padding
+                # below chunk and 0 rows of padding elsewhere
+                arr.append(np.pad(
+                    chunk,
+                    ((0, pad_length), (0, 0)),
+                    mode=str("constant"),
+                    constant_values=constant_values))
+            return [arr]
+
+        workspace.FeedBlob('l', lengths)
+        workspace.FeedBlob('d', data)
+        inputs = [lengths, data]
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=inputs,
+            reference=pack_segments_ref,
+        )
+        workspace.FeedBlob('l', lengths)
+        workspace.FeedBlob('d', data)
+
        workspace.RunOperatorOnce(core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t']))
        workspace.RunOperatorOnce(core.CreateOperator(
@@ -67,6 +100,7 @@
        exponentiated = workspace.FetchBlob('r')
        assert(exponentiated[0, -1, 0] == 0.0)
+
if __name__ == "__main__":
    import unittest
    unittest.main()