CUDA support for PackSegments Op
Summary: Replace GPUFallbackOp with a native CUDA implementation
Reviewed By: akyrola
Differential Revision: D6423200
fbshipit-source-id: 47dfecbc486e9a8bf0cc6b897ab8b6a2488caa34
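Note: PackSegments turns a ragged batch (rows of DATA grouped by LENGTHS)
into a dense [batch_size, max_length, ...] tensor, padding the tail of each
short sequence. A minimal host-side sketch of those semantics (illustrative
plain C++ with made-up values, not the Caffe2 API):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // lengths = [1, 2, 3]: DATA rows 0, 1-2, and 3-5 form three sequences.
      std::vector<int> lengths = {1, 2, 3};
      std::vector<float> data = {1, 2, 2, 3, 3, 3}; // cell_size == 1
      int max_length = *std::max_element(lengths.begin(), lengths.end());
      // Output is [batch_size, max_length], padded with 0.
      std::vector<float> packed(lengths.size() * max_length, 0.f);
      int start = 0;
      for (size_t i = 0; i < lengths.size(); ++i) {
        std::copy(data.begin() + start,
                  data.begin() + start + lengths[i],
                  packed.begin() + i * max_length);
        start += lengths[i];
      }
      for (float v : packed) std::printf("%g ", v); // 1 0 0 2 2 0 3 3 3
      return 0;
    }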
diff --git a/caffe2/operators/pack_segments.cc b/caffe2/operators/pack_segments.cc
index 541085a..cf6bd6d 100644
--- a/caffe2/operators/pack_segments.cc
+++ b/caffe2/operators/pack_segments.cc
@@ -18,6 +18,87 @@
namespace caffe2 {
+template <>
+template <typename T>
+bool PackSegmentsOp<CPUContext>::DoRunWithType() {
+  return DispatchHelper<
+      TensorTypes2<char, int32_t, int64_t, float, std::string>,
+      T>::call(this, Input(DATA));
+}
+
+template <>
+template <typename T, typename Data_T>
+bool PackSegmentsOp<CPUContext>::DoRunWithType2() {
+  const auto& data = Input(DATA);
+  const auto& lengths = Input(LENGTHS);
+  auto* output = Output(0);
+  Tensor<CPUContext>* presence_mask = nullptr;
+  if (return_presence_mask_) {
+    presence_mask = Output(1);
+  }
+
+  CAFFE_ENFORCE(data.ndim() >= 1, "DATA should be at least 1-D");
+  CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTHS should be 1-D");
+
+  // Find the length of the longest sequence.
+  const T* l = lengths.template data<T>();
+  T max_length = 0;
+  for (T i = 0; i < lengths.dim(0); ++i) {
+    max_length = std::max(max_length, l[i]);
+  }
+
+  auto shape = data.dims(); // Shape of output is batch_size x max_len x ...
+  shape[0] = max_length;
+  shape.insert(shape.begin(), lengths.size());
+  output->Resize(shape);
+
+  // create output tensor
+  auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
+
+  bool* presence_mask_data = nullptr;
+  if (return_presence_mask_) {
+    // Shape of presence is batch_size x max_len
+    std::vector<caffe2::TIndex> presence_shape{lengths.size(), max_length};
+    presence_mask->Resize(presence_shape);
+    presence_mask_data = presence_mask->template mutable_data<bool>();
+  }
+
+  if (!data.dim(0)) {
+    // Return empty output (with the proper shape)
+    return true;
+  }
+
+  // Do padding
+  if (output->template IsType<float>()) {
+    math::Set<float, CPUContext>(
+        output->size(),
+        padding_,
+        output->template mutable_data<float>(),
+        &context_);
+  }
+  if (return_presence_mask_) {
+    memset(presence_mask_data, (int)false, presence_mask->size());
+  }
+
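+  // Copy each sequence's rows into its slot in the packed output.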
+  int block_size = data.size() / data.dim(0);
+  int block_bytesize = data.nbytes() / data.dim(0);
+  const auto* d = static_cast<const char*>(data.raw_data());
+  int start = 0;
+  for (int i = 0; i < lengths.dim(0); ++i) {
+    context_.template CopyItems<CPUContext, CPUContext>(
+        data.meta(),
+        l[i] * block_size,
+        d + block_bytesize * start,
+        out + block_bytesize * max_length * i);
+    if (return_presence_mask_) {
+      memset(presence_mask_data + max_length * i, (int)true, l[i]);
+    }
+    start += l[i];
+  }
+
+  return true;
+}
+
REGISTER_CPU_OPERATOR(PackSegments, PackSegmentsOp<CPUContext>);
REGISTER_CPU_OPERATOR(UnpackSegments, UnpackSegmentsOp<CPUContext>);
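Note: the CPU path relies on Caffe2's two-level dispatch: RunOnDevice()
selects the LENGTHS dtype T, and DoRunWithType<T>() then selects the DATA
dtype Data_T via TensorTypes2, landing in DoRunWithType2<T, Data_T>(). A
simplified standalone sketch of that pattern (hypothetical FakeOp, with the
second level hard-coded in place of the real DispatchHelper):

    #include <cstdint>
    #include <iostream>
    #include <typeinfo>

    struct FakeOp {
      template <typename T>
      bool DoRunWithType(); // level 1: dispatched on the LENGTHS dtype
      template <typename T, typename Data_T>
      bool DoRunWithType2() { // level 2: dispatched on the DATA dtype
        std::cout << "lengths: " << typeid(T).name()
                  << " data: " << typeid(Data_T).name() << "\n";
        return true;
      }
    };

    template <typename T>
    bool FakeOp::DoRunWithType() {
      // The real op asks DispatchHelper<TensorTypes2<...>, T> to pick Data_T
      // from the runtime dtype of Input(DATA); here one case is hard-coded.
      return DoRunWithType2<T, float>();
    }

    int main() {
      FakeOp op;
      op.DoRunWithType<int32_t>(); // as if LENGTHS were int32 and DATA float
      return 0;
    }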
diff --git a/caffe2/operators/pack_segments.cu b/caffe2/operators/pack_segments.cu
new file mode 100644
index 0000000..ea46199
--- /dev/null
+++ b/caffe2/operators/pack_segments.cu
@@ -0,0 +1,178 @@
+#include <cub/cub.cuh>
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/pack_segments.h"
+
+namespace caffe2 {
+
+namespace {
+
+template <typename T, typename Data_T>
+__global__ void PackSegmentsKernel(
+    const Data_T* data_ptr,
+    const T* lengths_ptr,
+    const T* lengths_cum_sum,
+    const T max_length,
+    const int64_t num_seq,
+    const int64_t cell_size,
+    Data_T padding,
+    Data_T* out_ptr) {
+  CUDA_1D_KERNEL_LOOP(i, num_seq * max_length * cell_size) {
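+    // Recover (sequence, cell within the sequence, offset within the cell)
+    // from the flat index i over the [num_seq, max_length, cell_size] output.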
+    int seq = (i / cell_size) / max_length;
+    int cell = (i / cell_size) % max_length;
+    int offset = i % cell_size;
+    if (cell >= lengths_ptr[seq]) {
+      out_ptr[i] = padding;
+    } else {
+      int32_t idx = (lengths_cum_sum[seq] + cell) * cell_size + offset;
+      out_ptr[i] = data_ptr[idx];
+    }
+  }
+}
+
+template <typename T>
+T array_max(
+    const T* dev_array,
+    int64_t num_items,
+    Tensor<CUDAContext>& dev_max_buffer,
+    Tensor<CUDAContext>& dev_max,
+    Tensor<CPUContext>& host_max,
+    CUDAContext& context) {
+  // Retrieve buffer size
+  size_t temp_storage_bytes = 0;
+  cub::DeviceReduce::Max(
+      nullptr,
+      temp_storage_bytes,
+      dev_array,
+      dev_max.mutable_data<T>(),
+      num_items,
+      context.cuda_stream());
+
+  // Allocate temporary storage
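+  // Round temp_storage_bytes up to a whole number of T elements so the
+  // byte count fits in a Tensor<CUDAContext> of T.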
+  auto buffer_size = (temp_storage_bytes + sizeof(T)) / sizeof(T);
+  dev_max_buffer.Resize(buffer_size);
+  void* dev_temp_storage = static_cast<void*>(dev_max_buffer.mutable_data<T>());
+
+  // Find maximum
+  cub::DeviceReduce::Max(
+      dev_temp_storage,
+      temp_storage_bytes,
+      dev_array,
+      dev_max.mutable_data<T>(),
+      num_items,
+      context.cuda_stream());
+
+  // Copy to host
+  host_max.CopyFrom<CUDAContext>(dev_max);
+  return *host_max.data<T>();
+}
+
+template <typename T>
+void array_prefix_sum_exclusive(
+    const T* dev_array,
+    const int32_t num_items,
+    Tensor<CUDAContext>& prefix_buffer,
+    Tensor<CUDAContext>& prefix_sum,
+    CUDAContext& context) {
+  // Retrieve buffer size
+  size_t temp_storage_bytes = 0;
+  prefix_sum.Resize(num_items);
+  cub::DeviceScan::ExclusiveSum(
+      nullptr,
+      temp_storage_bytes,
+      dev_array,
+      prefix_sum.mutable_data<T>(),
+      num_items,
+      context.cuda_stream());
+
+  // Allocate temporary storage
+  auto buffer_size = (temp_storage_bytes + sizeof(T)) / sizeof(T);
+  prefix_buffer.Resize(buffer_size);
+  void* dev_temp_storage = static_cast<void*>(prefix_buffer.mutable_data<T>());
+
+  // Exclusive sum
+  cub::DeviceScan::ExclusiveSum(
+      dev_temp_storage,
+      temp_storage_bytes,
+      dev_array,
+      prefix_sum.mutable_data<T>(),
+      num_items,
+      context.cuda_stream());
+}
+
+} // namespace
+
+template <>
+template <typename T>
+bool PackSegmentsOp<CUDAContext>::DoRunWithType() {
+  return DispatchHelper<TensorTypes2<char, int32_t, int64_t, float>, T>::call(
+      this, Input(DATA));
+}
+
+template <>
+template <typename T, typename Data_T>
+bool PackSegmentsOp<CUDAContext>::DoRunWithType2() {
+  const auto& data = Input(DATA);
+  const auto& lengths = Input(LENGTHS);
+  int64_t num_seq = lengths.dim(0);
+  const Data_T* data_ptr = data.data<Data_T>();
+  const T* lengths_ptr = lengths.data<T>();
+  auto* out = Output(0);
+
+  if (return_presence_mask_) {
+    CAFFE_THROW("CUDA version of PackSegments does not support presence mask.");
+  }
+  CAFFE_ENFORCE(data.ndim() >= 1, "DATA should be at least 1-D");
+  CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTHS should be 1-D");
+
+  // Find the length of the longest sequence.
+  dev_max_length_.Resize(1);
+  host_max_length_.Resize(1);
+  const T max_length = array_max<T>(
+      lengths_ptr,
+      num_seq,
+      dev_max_length_buffer_,
+      dev_max_length_,
+      host_max_length_,
+      context_);
+
+  // Compute the exclusive prefix sum of the lengths.
+  array_prefix_sum_exclusive<T>(
+      lengths_ptr,
+      num_seq,
+      lengths_prefix_sum_buffer_,
+      lengths_prefix_sum_,
+      context_);
+
+  // create output tensor
+  auto shape = data.dims(); // Shape of out is batch_size x max_len x ...
+  shape[0] = max_length;
+  shape.insert(shape.begin(), lengths.size());
+  out->Resize(shape);
+  Data_T* out_ptr = static_cast<Data_T*>(out->raw_mutable_data(data.meta()));
+
+  // Return empty out (with the proper shape) if the first dim is 0.
+  if (!data.dim(0)) {
+    return true;
+  }
+
+  // Launch one thread per output element; out-of-range cells get padding.
+  Data_T padding = out->IsType<float>() ? padding_ : 0;
+  int64_t cell_size = data.size() / data.dim(0);
+  PackSegmentsKernel<<<
+      CAFFE_GET_BLOCKS(num_seq * max_length * cell_size),
+      CAFFE_CUDA_NUM_THREADS,
+      0,
+      context_.cuda_stream()>>>(
+      data_ptr,
+      lengths_ptr,
+      lengths_prefix_sum_.data<T>(),
+      max_length,
+      num_seq,
+      cell_size,
+      padding,
+      out_ptr);
+  return true;
+}
+
+REGISTER_CUDA_OPERATOR(PackSegments, PackSegmentsOp<CUDAContext>);
+} // namespace caffe2
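Note: the kernel assigns one thread per element of the [num_seq, max_length,
cell_size] output and uses the exclusive prefix sum of LENGTHS to locate the
source row in the ragged input. A host-side sketch of the same index
arithmetic (made-up values, plain C++, mirroring PackSegmentsKernel):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> lengths = {1, 2, 3};
      std::vector<int> cum = {0, 1, 3}; // exclusive prefix sum of lengths
      int num_seq = 3, max_length = 3, cell_size = 2;
      std::vector<float> data = // 6 rows x cell_size 2, ragged input
          {1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3};
      float padding = 0.f;
      std::vector<float> out(num_seq * max_length * cell_size);
      for (int i = 0; i < (int)out.size(); ++i) {
        int seq = (i / cell_size) / max_length;  // which sequence
        int cell = (i / cell_size) % max_length; // position in the sequence
        int offset = i % cell_size;              // position within a cell
        out[i] = (cell >= lengths[seq])
            ? padding
            : data[(cum[seq] + cell) * cell_size + offset];
      }
      for (float v : out) std::printf("%g ", v);
      std::printf("\n"); // 1 1 0 0 0 0 2 2 2 2 0 0 3 3 3 3 3 3
      return 0;
    }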
diff --git a/caffe2/operators/pack_segments.h b/caffe2/operators/pack_segments.h
index 899fb04..d4bf1b5 100644
--- a/caffe2/operators/pack_segments.h
+++ b/caffe2/operators/pack_segments.h
@@ -48,81 +48,15 @@
    }
  }
-  bool RunOnDevice() override {
+  bool RunOnDevice() {
    return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
  }
  template <typename T>
-  bool DoRunWithType() {
-    const auto& data = Input(DATA);
-    const auto& lengths = Input(LENGTHS);
-    auto* output = Output(0);
-    Tensor<Context>* presence_mask = nullptr;
-    if (return_presence_mask_) {
-      presence_mask = Output(1);
-    }
+  bool DoRunWithType();
-    CAFFE_ENFORCE(data.ndim() >= 1, "DATA should be at least 1-D");
-    CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
-
-    // Find the length of the longest sequence.
-    const T* l = lengths.template data<T>();
-    T max_length = 0;
-    for (T i = 0; i < lengths.dim(0); ++i) {
-      max_length = std::max(max_length, l[i]);
-    }
-
-    auto shape = data.dims(); // Shape of output is batch_size x max_len x ...
-    shape[0] = max_length;
-    shape.insert(shape.begin(), lengths.size());
-    output->Resize(shape);
-
-    // create output tensor
-    auto* out = static_cast<char*>(output->raw_mutable_data(data.meta()));
-
-    bool* presence_mask_data = nullptr;
-    if (return_presence_mask_) {
-      // Shape of presence is batch_size x max_len
-      std::vector<caffe2::TIndex> presence_shape{lengths.size(), max_length};
-      presence_mask->Resize(presence_shape);
-      presence_mask_data = presence_mask->template mutable_data<bool>();
-    }
-
-    if (!data.dim(0)) {
-      // Return empty output (with the proper shape)
-      return true;
-    }
-
-    // Do padding
-    if (output->template IsType<float>()) {
-      math::Set<float, Context>(
-          output->size(),
-          padding_,
-          output->template mutable_data<float>(),
-          &context_);
-    }
-    if (return_presence_mask_) {
-      memset(presence_mask_data, (int)false, presence_mask->size());
-    }
-
-    int block_size = data.size() / data.dim(0);
-    int block_bytesize = data.nbytes() / data.dim(0);
-    const auto* d = static_cast<const char*>(data.raw_data());
-    int start = 0;
-    for (int i = 0; i < lengths.dim(0); ++i) {
-      context_.template CopyItems<Context, Context>(
-          data.meta(),
-          l[i] * block_size,
-          d + block_bytesize * start,
-          out + block_bytesize * max_length * i);
-      if (return_presence_mask_) {
-        memset(presence_mask_data + max_length * i, (int)true, l[i]);
-      }
-      start += l[i];
-    }
-
-    return true;
-  }
+  template <typename T, typename Data_T>
+  bool DoRunWithType2();
  INPUT_TAGS(LENGTHS, DATA);
@@ -130,6 +64,13 @@
  bool pad_minf_;
  float padding_;
  bool return_presence_mask_;
+
+  // Scratch space required by the CUDA version
+  Tensor<Context> lengths_prefix_sum_buffer_;
+  Tensor<Context> lengths_prefix_sum_;
+  Tensor<Context> dev_max_length_buffer_;
+  Tensor<Context> dev_max_length_;
+  Tensor<CPUContext> host_max_length_;
};
template <class Context>
diff --git a/caffe2/operators/pack_segments_op_gpu.cc b/caffe2/operators/pack_segments_op_gpu.cc
index f4b1f53..699186d 100644
--- a/caffe2/operators/pack_segments_op_gpu.cc
+++ b/caffe2/operators/pack_segments_op_gpu.cc
@@ -19,7 +19,6 @@
#include "caffe2/operators/pack_segments.h"
namespace caffe2 {
-REGISTER_CUDA_OPERATOR(PackSegments, GPUFallbackOp<PackSegmentsOp<CPUContext>>);
REGISTER_CUDA_OPERATOR(
    UnpackSegments,
    GPUFallbackOp<UnpackSegmentsOp<CPUContext>>);
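Note: array_max and array_prefix_sum_exclusive above both follow CUB's
standard two-phase convention: a first call with a null workspace pointer
only reports temp_storage_bytes, then the caller allocates that much scratch
and calls again to do the work. A self-contained sketch of the pattern
(raw cudaMalloc instead of the op's reusable scratch Tensors):

    #include <cub/cub.cuh>
    #include <cstdio>

    int main() {
      const int n = 4;
      int h_in[n] = {3, 1, 4, 1};
      int *d_in, *d_out;
      cudaMalloc(&d_in, n * sizeof(int));
      cudaMalloc(&d_out, sizeof(int));
      cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

      // Phase 1: null temp storage; CUB only reports the bytes it needs.
      void* d_temp = nullptr;
      size_t temp_bytes = 0;
      cub::DeviceReduce::Max(d_temp, temp_bytes, d_in, d_out, n);

      // Phase 2: allocate that much scratch and run the reduction for real.
      cudaMalloc(&d_temp, temp_bytes);
      cub::DeviceReduce::Max(d_temp, temp_bytes, d_in, d_out, n);

      int h_out = 0;
      cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
      std::printf("max = %d\n", h_out); // max = 4
      cudaFree(d_temp);
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }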
diff --git a/caffe2/python/operator_test/pack_ops_test.py b/caffe2/python/operator_test/pack_ops_test.py
index 73c0a71..5e1567a 100644
--- a/caffe2/python/operator_test/pack_ops_test.py
+++ b/caffe2/python/operator_test/pack_ops_test.py
@@ -22,7 +22,9 @@
import caffe2.python.hypothesis_test_util as hu
from hypothesis import given
+from hypothesis import strategies as st
import numpy as np
+import time
class TestTensorPackOps(hu.HypothesisTestCase):
@@ -62,37 +64,61 @@
        return pack_segments_ref
-    @given(**hu.gcs)
-    def test_pack_ops(self, gc, dc):
-        lengths = np.array([1, 2, 3], dtype=np.int32)
-        data = np.array([
-            [1.0, 1.0],
-            [2.0, 2.0],
-            [2.0, 2.0],
-            [3.0, 3.0],
-            [3.0, 3.0],
-            [3.0, 3.0]], dtype=np.float32)
+    @given(
+        num_seq=st.integers(10, 1500),
+        cell_size=st.integers(1, 100),
+        **hu.gcs
+    )
+    def test_pack_ops(self, num_seq, cell_size, gc, dc):
+        # create data
+        lengths = np.arange(num_seq, dtype=np.int32) + 1
+        num_cell = np.sum(lengths)
+        data = np.zeros(num_cell * cell_size, dtype=np.float32)
+        left = np.cumsum(np.arange(num_seq) * cell_size)
+        right = np.cumsum(lengths * cell_size)
+        for i in range(num_seq):
+            data[left[i]:right[i]] = i + 1.0
+        data.resize(num_cell, cell_size)
+        print("\nnum seq:{}, num cell: {}, cell size:{}\n".format(
+            num_seq, num_cell, cell_size)
+            + "=" * 60
+        )
+        # run test
        op = core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t'])
-        print(gc, dc)
-
        workspace.FeedBlob('l', lengths)
        workspace.FeedBlob('d', data)
-        inputs = [lengths, data]
+
+        start = time.time()
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
-            inputs=inputs,
+            inputs=[lengths, data],
            reference=self.pack_segments_ref(),
        )
-        workspace.FeedBlob('l', lengths)
-        workspace.FeedBlob('d', data)
+        end = time.time()
+        print("{} used time: {}".format(gc, end - start).replace('\n', ' '))
+        with core.DeviceScope(gc):
+            workspace.FeedBlob('l', lengths)
+            workspace.FeedBlob('d', data)
        workspace.RunOperatorOnce(core.CreateOperator(
-            'PackSegments', ['l', 'd'], ['t']))
+            'PackSegments',
+            ['l', 'd'],
+            ['t'],
+            device_option=gc))
        workspace.RunOperatorOnce(core.CreateOperator(
-            'UnpackSegments', ['l', 't'], ['newd']))
+            'UnpackSegments',
+            ['l', 't'],
+            ['newd'],
+            device_option=gc))
        assert((workspace.FetchBlob('newd') == workspace.FetchBlob('d')).all())
+
+    @given(
+        **hu.gcs_cpu_only
+    )
+    def test_pack_ops_str(self, gc, dc):
+        # GPU does not support strings; test the CPU implementation only.
        workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int64))
        strs = np.array([
            ["a", "a"],
@@ -104,9 +130,15 @@
            dtype='|S')
        workspace.FeedBlob('d', strs)
        workspace.RunOperatorOnce(core.CreateOperator(
-            'PackSegments', ['l', 'd'], ['t']))
+            'PackSegments',
+            ['l', 'd'],
+            ['t'],
+            device_option=gc))
        workspace.RunOperatorOnce(core.CreateOperator(
-            'UnpackSegments', ['l', 't'], ['newd']))
+            'UnpackSegments',
+            ['l', 't'],
+            ['newd'],
+            device_option=gc))
        assert((workspace.FetchBlob('newd') == workspace.FetchBlob('d')).all())
    def test_pad_minf(self):
@@ -134,7 +166,7 @@
        exponentiated = workspace.FetchBlob('r')
        assert(exponentiated[0, -1, 0] == 0.0)
-    @given(**hu.gcs)
+    @given(**hu.gcs_cpu_only)
    def test_presence_mask(self, gc, dc):
lengths = np.array([1, 2, 3], dtype=np.int32)
data = np.array(
@@ -161,8 +193,6 @@
        op = core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t', 'p'], return_presence_mask=True
        )
-        workspace.FeedBlob('l', lengths)
-        workspace.FeedBlob('d', data)
        workspace.RunOperatorOnce(op)
        output = workspace.FetchBlob('t')
@@ -177,8 +207,7 @@
        self.assertEqual(presence_mask.shape, expected_presence_mask.shape)
        np.testing.assert_array_equal(presence_mask, expected_presence_mask)
-    @given(**hu.gcs)
-    def test_presence_mask_empty(self, gc, dc):
+    def test_presence_mask_empty(self):
        lengths = np.array([], dtype=np.int32)
        data = np.array([], dtype=np.float32)