implement SliceOp for GPU
Summary: Implementation of the SliceOp for CUDA
Reviewed By: akyrola
Differential Revision: D5254287
fbshipit-source-id: 0a1660e1aa161fd088a2d8f886e019c05a1919a2
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index 713b2bd..05daae8 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -15,6 +15,146 @@
namespace caffe2 {
CAFFE_KNOWN_TYPE(const float*);
+namespace {
+// Copies `num_blocks` byte ranges from a source buffer into a destination
+// buffer. For every index produced by CUDA_1D_KERNEL_LOOP, one thread copies
+// `dst_block_size_bytes` bytes from
+//   src_offset_bytes + index * src_block_size_bytes
+// to
+//   dst_offset_bytes + index * dst_block_size_bytes.
+//
+// Parameters:
+//   src_offset_bytes     - address of the first byte to read for index 0.
+//   src_block_size_bytes - distance in bytes between consecutive source ranges.
+//   dst_offset_bytes     - address of the first byte to write for index 0.
+//   dst_block_size_bytes - bytes copied per index; also the destination stride.
+//   itemsize             - not used by the kernel; kept so existing launch
+//                          sites keep compiling.
+//   num_blocks           - number of ranges to copy.
+__global__ void SliceCopyKernel(
+    char* src_offset_bytes,
+    int src_block_size_bytes,
+    char* dst_offset_bytes,
+    int dst_block_size_bytes,
+    int itemsize,
+    int num_blocks) {
+  CUDA_1D_KERNEL_LOOP(index, num_blocks) {
+    // Widen the index before multiplying: an int * int product would wrap
+    // once the byte offset no longer fits in a 32-bit signed integer.
+    const size_t block = static_cast<size_t>(index);
+    const char* local_src_offset_bytes =
+        src_offset_bytes + block * src_block_size_bytes;
+    char* local_dst_offset_bytes =
+        dst_offset_bytes + block * dst_block_size_bytes;
+    memcpy(
+        local_dst_offset_bytes, local_src_offset_bytes, dst_block_size_bytes);
+  }
+}
+} // namespace
+
+// CUDA implementation of Slice for int start/end indices.
+//
+// Inputs:
+//   0: data   - tensor to slice.
+//   1: starts - 1-D tensor of per-dimension start indices.
+//   2: ends   - 1-D tensor of per-dimension end indices (exclusive).
+// Output:
+//   0: the sliced tensor.
+//
+// `starts` and `ends` must have the same length, which may be shorter than
+// data.ndim(); dimensions without an entry are taken in full. A negative
+// index b on a dimension of size d is interpreted as d + 1 + b, so -1 as an
+// end means "through the last element". Only one dimension may actually be
+// narrowed; anything else trips a CAFFE_ENFORCE. Always returns true when no
+// enforce fails.
+template <>
+bool SliceOp<int, CUDAContext>::RunOnDevice() {
+  auto* output = Output(0);
+  auto& data = Input(0);
+
+  auto& starts = Input(1);
+  auto& ends = Input(2);
+
+  CAFFE_ENFORCE_EQ(starts.ndim(), 1);
+  CAFFE_ENFORCE_EQ(ends.ndim(), 1);
+  CAFFE_ENFORCE_GE(data.ndim(), starts.size());
+  CAFFE_ENFORCE_EQ(starts.size(), ends.size());
+
+  // The index tensors are CUDA tensors; copy them into CPU tensors so the
+  // bounds can be validated and the output shape computed on the host.
+  TensorCPU starts_host;
+  TensorCPU ends_host;
+  starts_host.template CopyFrom<CUDAContext>(starts);
+  ends_host.template CopyFrom<CUDAContext>(ends);
+
+  auto* starts_data_host = starts_host.template data<int>();
+  auto* ends_data_host = ends_host.template data<int>();
+
+  std::vector<int> starts_idx(data.ndim());
+  std::vector<int> ends_idx(data.ndim());
+  std::vector<int> dst_sizes(data.ndim());
+
+  for (int i = 0; i < data.ndim(); ++i) {
+    if (i >= starts.size()) {
+      // No start/end given for this dimension: keep it whole. The output
+      // extent must be recorded as well, otherwise it stays at the vector's
+      // zero default and the output is sized too small for the copy below.
+      starts_idx[i] = 0;
+      ends_idx[i] = data.dims()[i];
+      dst_sizes[i] = data.dims()[i];
+      continue;
+    }
+    if (data.dims()[i] > 0) {
+      auto start = starts_data_host[i];
+      auto end = ends_data_host[i];
+      // Negative indices count from one past the end of the dimension.
+      if (start < 0) {
+        start = data.dims()[i] + 1 + start;
+      }
+      if (end < 0) {
+        end = data.dims()[i] + 1 + end;
+      }
+      CAFFE_ENFORCE_GE(start, 0);
+      CAFFE_ENFORCE_GE(end, 0);
+      CAFFE_ENFORCE_LT(start, data.dims()[i]);
+      CAFFE_ENFORCE_LE(end, data.dims()[i]);
+      CAFFE_ENFORCE_GE(end, start);
+      starts_idx[i] = start;
+      ends_idx[i] = end;
+      dst_sizes[i] = end - start;
+    } else {
+      // Zero-sized dimension: nothing to select.
+      starts_idx[i] = 0;
+      ends_idx[i] = 0;
+      dst_sizes[i] = 0;
+    }
+  }
+
+  if (data.size() <= 0) {
+    // When the input is empty, we do not need to do copy.
+    output->Resize(dst_sizes);
+    output->raw_mutable_data(data.meta());
+    return true;
+  }
+  // for now only supports slicing in 1 dimension
+  int dim = -1;
+  for (int i = 0; i < data.ndim(); ++i) {
+    if (starts_idx[i] > 0 || ends_idx[i] < data.dims()[i]) {
+      CAFFE_ENFORCE_EQ(
+          dim, -1, "Currently only possible to slice in 1 dimension.");
+      dim = i;
+    }
+  }
+  if (dim == -1) {
+    // Every dimension is taken in full, so the result is a plain copy.
+    output->CopyFrom(data, &context_);
+    return true;
+  }
+  // `unit` is the number of elements spanned by one index step along `dim`
+  // (product of the trailing dimensions); `num_blocks` is the product of the
+  // leading dimensions, i.e. how many separate ranges have to be copied.
+  auto unit = std::accumulate(
+      data.dims().begin() + dim + 1,
+      data.dims().end(),
+      1,
+      std::multiplies<int>());
+  auto num_blocks = std::accumulate(
+      data.dims().begin(),
+      data.dims().begin() + dim,
+      1,
+      std::multiplies<int>());
+  output->Resize(dst_sizes);
+  auto* src_bytes = (char*)data.raw_data();
+  // Requesting the mutable pointer with the input's meta gives the output
+  // the same element type as the input.
+  auto* dst_bytes = (char*)output->raw_mutable_data(data.meta());
+
+  auto src_block_size = unit * data.dims()[dim];
+  auto dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
+  auto src_offset = unit * starts_idx[dim];
+
+  if (num_blocks == 0 || dst_block_size == 0) {
+    // Nothing to copy (e.g. start == end on the sliced dimension).
+    return true;
+  }
+
+  auto itemsize = data.meta().itemsize();
+  auto src_block_size_bytes = itemsize * src_block_size;
+  auto dst_block_size_bytes = itemsize * dst_block_size;
+  // Skip the first `starts_idx[dim]` rows of the first source block.
+  auto src_offset_bytes = src_bytes + itemsize * src_offset;
+  auto dst_offset_bytes = dst_bytes;
+
+  // Launched on the operator's CUDA stream; each kernel index copies one
+  // contiguous range of dst_block_size_bytes bytes.
+  SliceCopyKernel<<<
+      std::min(num_blocks, CAFFE_MAXIMUM_NUM_BLOCKS),
+      CAFFE_CUDA_NUM_THREADS,
+      0,
+      context_.cuda_stream()>>>(
+      src_offset_bytes,
+      src_block_size_bytes,
+      dst_offset_bytes,
+      dst_block_size_bytes,
+      itemsize,
+      num_blocks);
+  return true;
+}
+
+// Register the int-indexed CUDA specialization above under the operator
+// name "Slice".
+REGISTER_CUDA_OPERATOR(Slice, SliceOp<int, CUDAContext>);
+
__global__ void NanCheckKernel(int N, const float* X, bool* result) {
bool has_nan = false;
CUDA_1D_KERNEL_LOOP(i, N) {
diff --git a/caffe2/python/operator_test/utility_ops_test.py b/caffe2/python/operator_test/utility_ops_test.py
index 1ac3e70..61aa64b 100644
--- a/caffe2/python/operator_test/utility_ops_test.py
+++ b/caffe2/python/operator_test/utility_ops_test.py
@@ -8,10 +8,35 @@
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np
+import random
class TestUtilityOps(hu.HypothesisTestCase):
+ @given(X=hu.tensor(), neg=st.booleans(), **hu.gcs)
+ def test_slice(self, X, neg, gc, dc):
+ X = X.astype(dtype=np.float32)
+ dim = random.randint(0, X.ndim - 1)
+ slice_start = random.randint(0, X.shape[dim] - 1)
+ slice_end = random.randint(slice_start, X.shape[dim] - 1)
+ starts = np.array([0] * X.ndim).astype(np.int32)
+ ends = np.array([-1] * X.ndim).astype(np.int32)
+ starts[dim] = slice_start
+ ends[dim] = slice_end
+
+ op = core.CreateOperator(
+ "Slice", ["X", "starts", "ends"], ["Y"], device_option=gc
+ )
+
+ def slice_ref(X, starts, ends):
+ slc = [slice(None)] * X.ndim
+ slc[dim] = slice(slice_start, slice_end)
+ return [X[slc]]
+
+ self.assertReferenceChecks(gc, op, [X, starts, ends], slice_ref)
+
+ self.assertDeviceChecks(dc, op, [X, starts, ends], [0])
+
@given(dtype=st.sampled_from([np.float32, np.int32, np.int64]),
ndims=st.integers(min_value=1, max_value=5),
seed=st.integers(min_value=0, max_value=65536),