cuda unique op
Summary:
cuda unique op , unittest provided, will provide benchmark agains CPU
SpeedUp results for synthetic real data. Input of size 20k, range[1, 10million], **~5x** speedup
CPU 9.05795(ms) Unique
GPU 1.79434(ms) Unique
SpeedUp results for 5x synthetic data. Input of size 1 million, range[1, 10million] **~13.7x** speedup
CPU 54.7539(ms) Unique
GPU 3.99473(ms) Unique
Reviewed By: akyrola
Differential Revision: D5007726
fbshipit-source-id: 0a00c518fd1809d0ae8c6cfcba09b0bd982ffaff
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 99f0f20..9e0b149 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -9,6 +9,48 @@
return DoRunWithType<float>();
}
+template <>
+template <typename T>
+void UniqueOp<CPUContext>::DoRun() {
+ auto& inputTensor = Input(0);
+ // use dim32 to enforce that it's fine to have remapping of type int
+ int N = inputTensor.dim32(0);
+ CAFFE_ENFORCE_EQ(inputTensor.ndim(), 1, "Input should be a vector");
+ auto* uniqueTensor = Output(UNIQUE);
+
+ int* remapping = nullptr;
+ if (REMAPPING < OutputSize()) {
+ auto* remappingTensor = Output(REMAPPING);
+ remappingTensor->ResizeLike(inputTensor);
+ remapping = remappingTensor->template mutable_data<int>();
+ }
+
+ const T* input = inputTensor.template data<T>();
+ // TODO(dzhulgakov): if perf becomes an issue consider doing hash table
+ // instead of sorting
+ order_.resize(N);
+ std::iota(order_.begin(), order_.end(), 0);
+ std::sort(order_.begin(), order_.end(), [input](const int x, const int y) {
+ return input[x] < input[y];
+ });
+ int K = N;
+ for (int i = 1; i < N; ++i) {
+ K -= input[order_[i]] == input[order_[i - 1]];
+ }
+ uniqueTensor->Resize(K);
+ T* unique = uniqueTensor->template mutable_data<T>();
+ K = 0;
+ T prev = -1;
+ for (int i = 0; i < N; ++i) {
+ if (i == 0 || prev != input[order_[i]]) {
+ prev = unique[K++] = input[order_[i]];
+ }
+ if (remapping) {
+ remapping[order_[i]] = K - 1;
+ }
+ }
+}
+
namespace {
REGISTER_CPU_OPERATOR(WallClockTime, WallClockTimeOp<CPUContext>);
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index 7a28fe6..9cdf2eb 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -4,6 +4,11 @@
// and std::isinf are declared constexpr there and the nvidia
// compiler throws an error because of it
+#include <thrust/device_vector.h>
+#include <thrust/sequence.h>
+#include <thrust/sort.h>
+#include <thrust/system/cuda/execution_policy.h>
+#include <thrust/unique.h>
#include "caffe2/core/context_gpu.h"
#include "utility_ops.h"
@@ -289,4 +294,106 @@
ScatterWeightedSum,
ScatterWeightedSumOp<float, CUDAContext>);
+#if THRUST_VERSION >= 100800
+__global__ void remap_kernel(
+ thrust::device_ptr<int> second_order,
+ thrust::device_ptr<int> order,
+ int* output,
+ int N,
+ int K) {
+ int i = blockDim.x * blockIdx.x + threadIdx.x;
+ if (i >= K)
+ return;
+ int idx = second_order[i];
+ output[order[idx]] = i;
+ // Maybe cuda 1D kernel?
+ for (idx++; idx < N && (i == K - 1 || idx != second_order[i + 1]); idx++) {
+ output[order[idx]] = i;
+ }
+ return;
+}
+
+template <>
+template <typename T>
+void UniqueOp<CUDAContext>::DoRun() {
+ auto& inputTensor = Input(0);
+ // use dim32 to enforce that it's fine to have remapping of type int
+ int N = inputTensor.dim32(0);
+ CAFFE_ENFORCE_EQ(inputTensor.ndim(), 1, "Input should be a vector");
+ auto* uniqueTensor = Output(UNIQUE);
+
+ int* remapping = nullptr;
+ if (REMAPPING < OutputSize()) {
+ auto* remappingTensor = Output(REMAPPING);
+ remappingTensor->ResizeLike(inputTensor);
+ remapping = remappingTensor->template mutable_data<int>();
+ }
+
+ const T* input = inputTensor.template data<T>();
+ thrust_unique_buffer_.Resize(N);
+ auto* buffer = thrust_unique_buffer_.template mutable_data<T>();
+ context_.template CopyItems<CUDAContext, CUDAContext>(
+ inputTensor.meta(), N, input, buffer);
+
+ // Create two vector of {0, 1, ..., N-1} on CUDA device
+ thrust::device_vector<int> order1(N), order2(N);
+ thrust::sequence(
+ thrust::cuda::par.on(context_.cuda_stream()),
+ order1.begin(),
+ order1.end());
+ thrust::sequence(
+ thrust::cuda::par.on(context_.cuda_stream()),
+ order2.begin(),
+ order2.end());
+
+ // Sort the input along with order vector. So now we know where each element
+ // is permutated to. For example:
+ // input1 = 1,3,5,1,5,7,9
+ // order1 = 0,1,2,3,4,5,6
+ // Now we have:
+ // output = 1,1,3,5,5,7,9
+ // order1 = 0,3,1,2,4,5,6
+ thrust::sort_by_key(
+ thrust::cuda::par.on(context_.cuda_stream()),
+ buffer,
+ buffer + N,
+ order1.begin());
+
+ // Use consequent unique op to get another order_buffer
+ // input2 = 1,1,3,5,5,7,9
+ // order2 = 0,1,2,3,4,5,6
+ // Now we have:
+ // output = 1,3,5,7,9
+ // order2 = 0,2,3,5,6
+ auto new_last = thrust::unique_by_key(
+ thrust::cuda::par.on(context_.cuda_stream()),
+ buffer,
+ buffer + N,
+ order2.begin());
+ int K = new_last.first - buffer;
+
+ uniqueTensor->Resize(K);
+ T* unique = uniqueTensor->template mutable_data<T>();
+ context_.template CopyItems<CUDAContext, CUDAContext>(
+ thrust_unique_buffer_.meta(), K, buffer, unique);
+
+ // Compute the remapping. For example, for the number 1, if we look at
+ // order2[0] and order2[1], we know that input2[0:2) are all 1. They are all
+ // remapped to 0 in final input. And from order1, we know where they come
+ // from. The rest is easy.
+ if (remapping != nullptr) {
+ // record remap
+ remap_kernel<<<
+ CAFFE_GET_BLOCKS(K),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(
+ order2.data(), order1.data(), remapping, N, K);
+ }
+}
+namespace {
+REGISTER_CUDA_OPERATOR(Unique, UniqueOp<CUDAContext>);
+} // namespace
+#endif // THRUST_VERSION >= 100800
+
} // namespace caffe2
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 6cf0270..7d2539b 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -1382,47 +1382,12 @@
private:
vector<int> order_;
+ Tensor<Context> thrust_unique_buffer_;
+ Tensor<Context> cuda_order_buffer_;
+ Tensor<Context> second_order_buffer_;
template <typename T>
- void DoRun() {
- auto& inputTensor = Input(0);
- // use dim32 to enforce that it's fine to have remapping of type int
- int N = inputTensor.dim32(0);
- CAFFE_ENFORCE_EQ(inputTensor.ndim(), 1, "Input should be a vector");
- auto* uniqueTensor = Output(UNIQUE);
-
- int* remapping = nullptr;
- if (REMAPPING < OutputSize()) {
- auto* remappingTensor = Output(REMAPPING);
- remappingTensor->ResizeLike(inputTensor);
- remapping = remappingTensor->template mutable_data<int>();
- }
-
- const T* input = inputTensor.template data<T>();
- // TODO(dzhulgakov): if perf becomes an issue consider doing hash table
- // instead of sorting
- order_.resize(N);
- std::iota(order_.begin(), order_.end(), 0);
- std::sort(order_.begin(), order_.end(), [input](const int x, const int y) {
- return input[x] < input[y];
- });
- int K = N;
- for (int i = 1; i < N; ++i) {
- K -= input[order_[i]] == input[order_[i - 1]];
- }
- uniqueTensor->Resize(K);
- T* unique = uniqueTensor->template mutable_data<T>();
- K = 0;
- T prev = -1;
- for (int i = 0; i < N; ++i) {
- if (i == 0 || prev != input[order_[i]]) {
- prev = unique[K++] = input[order_[i]];
- }
- if (remapping) {
- remapping[order_[i]] = K - 1;
- }
- }
- }
+ void DoRun();
public:
OUTPUT_TAGS(UNIQUE, REMAPPING);
diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py
index 564c748..10ffddc 100644
--- a/caffe2/python/hypothesis_test.py
+++ b/caffe2/python/hypothesis_test.py
@@ -906,7 +906,7 @@
dtype=np.int32,
elements=st.integers(min_value=0, max_value=10)),
with_remapping=st.booleans(),
- **hu.gcs_cpu_only)
+ **hu.gcs)
def test_unique(self, input, with_remapping, gc, dc):
op = core.CreateOperator(
"Unique",