cuda unique op

Summary:
CUDA unique op, unittest provided, will provide benchmark against CPU

Speedup results for synthetic real data. Input of size 20k, range [1, 10 million], **~5x** speedup

  CPU 9.05795(ms) Unique
  GPU 1.79434(ms) Unique

Speedup results for 5x synthetic data. Input of size 1 million, range [1, 10 million], **~13.7x** speedup

  CPU 54.7539(ms) Unique
  GPU 3.99473(ms) Unique

Reviewed By: akyrola

Differential Revision: D5007726

fbshipit-source-id: 0a00c518fd1809d0ae8c6cfcba09b0bd982ffaff
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 99f0f20..9e0b149 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -9,6 +9,48 @@
   return DoRunWithType<float>();
 }
 
+template <>
+template <typename T>
+void UniqueOp<CPUContext>::DoRun() {
+  auto& inputTensor = Input(0);
+  // Read N via dim32 so that indices fit the int-typed remapping output.
+  int N = inputTensor.dim32(0);
+  CAFFE_ENFORCE_EQ(inputTensor.ndim(), 1, "Input should be a vector");
+  auto* uniqueTensor = Output(UNIQUE);
+
+  int* remapping = nullptr; // stays null when REMAPPING is not requested
+  if (REMAPPING < OutputSize()) {
+    auto* remappingTensor = Output(REMAPPING);
+    remappingTensor->ResizeLike(inputTensor);
+    remapping = remappingTensor->template mutable_data<int>();
+  }
+
+  const T* input = inputTensor.template data<T>();
+  // TODO(dzhulgakov): if perf becomes an issue consider doing hash table
+  // instead of sorting
+  order_.resize(N);
+  std::iota(order_.begin(), order_.end(), 0); // order_ = 0, 1, ..., N-1
+  std::sort(order_.begin(), order_.end(), [input](const int x, const int y) {
+    return input[x] < input[y]; // sort indices by the value they point to
+  });
+  int K = N; // unique count: N minus one per adjacent equal pair
+  for (int i = 1; i < N; ++i) {
+    K -= input[order_[i]] == input[order_[i - 1]];
+  }
+  uniqueTensor->Resize(K);
+  T* unique = uniqueTensor->template mutable_data<T>();
+  K = 0; // reused as the write cursor into unique
+  T prev = -1;
+  for (int i = 0; i < N; ++i) {
+    if (i == 0 || prev != input[order_[i]]) {
+      prev = unique[K++] = input[order_[i]];
+    }
+    if (remapping) {
+      remapping[order_[i]] = K - 1; // index of this element's value in unique
+    }
+  }
+}
+
 namespace {
 
 REGISTER_CPU_OPERATOR(WallClockTime, WallClockTimeOp<CPUContext>);
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index 7a28fe6..9cdf2eb 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -4,6 +4,11 @@
 // and std::isinf are declared constexpr there and the nvidia
 // compiler throws an error because of it
 
+#include <thrust/device_vector.h>
+#include <thrust/sequence.h>
+#include <thrust/sort.h>
+#include <thrust/system/cuda/execution_policy.h>
+#include <thrust/unique.h>
 #include "caffe2/core/context_gpu.h"
 #include "utility_ops.h"
 
@@ -289,4 +294,106 @@
     ScatterWeightedSum,
     ScatterWeightedSumOp<float, CUDAContext>);
 
+#if THRUST_VERSION >= 100800
+__global__ void remap_kernel(
+    thrust::device_ptr<int> second_order,
+    thrust::device_ptr<int> order,
+    int* output,
+    int N,
+    int K) {
+  int i = blockDim.x * blockIdx.x + threadIdx.x; // one thread per unique value
+  if (i >= K)
+    return;
+  int idx = second_order[i]; // start of unique value i's run
+  output[order[idx]] = i;
+  // Walk the rest of this run serially. TODO: maybe a cuda 1D kernel instead?
+  for (idx++; idx < N && (i == K - 1 || idx != second_order[i + 1]); idx++) {
+    output[order[idx]] = i;
+  }
+  return;
+}
+
+template <>
+template <typename T>
+void UniqueOp<CUDAContext>::DoRun() {
+  auto& inputTensor = Input(0);
+  // Read N via dim32 so that indices fit the int-typed remapping output.
+  int N = inputTensor.dim32(0);
+  CAFFE_ENFORCE_EQ(inputTensor.ndim(), 1, "Input should be a vector");
+  auto* uniqueTensor = Output(UNIQUE);
+
+  int* remapping = nullptr;
+  if (REMAPPING < OutputSize()) {
+    auto* remappingTensor = Output(REMAPPING);
+    remappingTensor->ResizeLike(inputTensor);
+    remapping = remappingTensor->template mutable_data<int>();
+  }
+
+  const T* input = inputTensor.template data<T>();
+  thrust_unique_buffer_.Resize(N);
+  auto* buffer = thrust_unique_buffer_.template mutable_data<T>();
+  context_.template CopyItems<CUDAContext, CUDAContext>(
+      inputTensor.meta(), N, input, buffer);
+
+  // Create two vectors of {0, 1, ..., N-1} on CUDA device
+  thrust::device_vector<int> order1(N), order2(N);
+  thrust::sequence(
+      thrust::cuda::par.on(context_.cuda_stream()),
+      order1.begin(),
+      order1.end());
+  thrust::sequence(
+      thrust::cuda::par.on(context_.cuda_stream()),
+      order2.begin(),
+      order2.end());
+
+  // Sort the input copy along with order1, so that order1 records the
+  // original position of each sorted element. For example:
+  //    input1 = 1,3,5,1,5,7,9
+  //    order1 = 0,1,2,3,4,5,6
+  // Now we have:
+  //    output = 1,1,3,5,5,7,9
+  //    order1 = 0,3,1,2,4,5,6
+  thrust::sort_by_key(
+      thrust::cuda::par.on(context_.cuda_stream()),
+      buffer,
+      buffer + N,
+      order1.begin());
+
+  // A subsequent unique_by_key leaves in order2 the start of each value's run
+  //    input2 = 1,1,3,5,5,7,9
+  //    order2 = 0,1,2,3,4,5,6
+  // Now we have:
+  //    output = 1,3,5,7,9
+  //    order2 = 0,2,3,5,6
+  auto new_last = thrust::unique_by_key(
+      thrust::cuda::par.on(context_.cuda_stream()),
+      buffer,
+      buffer + N,
+      order2.begin());
+  int K = new_last.first - buffer; // number of unique values
+
+  uniqueTensor->Resize(K);
+  T* unique = uniqueTensor->template mutable_data<T>();
+  context_.template CopyItems<CUDAContext, CUDAContext>(
+      thrust_unique_buffer_.meta(), K, buffer, unique);
+
+  // Compute the remapping. For example, for the number 1, if we look at
+  // order2[0] and order2[1], we know that input2[0:2) are all 1. They are all
+  // remapped to 0 in the final output. And from order1, we know where they
+  // came from. The rest is easy.
+  if (remapping != nullptr) {
+    // Record remap. NOTE(review): confirm K == 0 does not launch 0 blocks.
+    remap_kernel<<<
+        CAFFE_GET_BLOCKS(K),
+        CAFFE_CUDA_NUM_THREADS,
+        0,
+        context_.cuda_stream()>>>(
+        order2.data(), order1.data(), remapping, N, K);
+  }
+}
+namespace {
+REGISTER_CUDA_OPERATOR(Unique, UniqueOp<CUDAContext>);
+} // namespace
+#endif // THRUST_VERSION >= 100800
+
 }  // namespace caffe2
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 6cf0270..7d2539b 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -1382,47 +1382,12 @@
 
  private:
   vector<int> order_;
+  Tensor<Context> thrust_unique_buffer_;
+  Tensor<Context> cuda_order_buffer_;
+  Tensor<Context> second_order_buffer_;
 
   template <typename T>
-  void DoRun() {
-    auto& inputTensor = Input(0);
-    // use dim32 to enforce that it's fine to have remapping of type int
-    int N = inputTensor.dim32(0);
-    CAFFE_ENFORCE_EQ(inputTensor.ndim(), 1, "Input should be a vector");
-    auto* uniqueTensor = Output(UNIQUE);
-
-    int* remapping = nullptr;
-    if (REMAPPING < OutputSize()) {
-      auto* remappingTensor = Output(REMAPPING);
-      remappingTensor->ResizeLike(inputTensor);
-      remapping = remappingTensor->template mutable_data<int>();
-    }
-
-    const T* input = inputTensor.template data<T>();
-    // TODO(dzhulgakov): if perf becomes an issue consider doing hash table
-    // instead of sorting
-    order_.resize(N);
-    std::iota(order_.begin(), order_.end(), 0);
-    std::sort(order_.begin(), order_.end(), [input](const int x, const int y) {
-      return input[x] < input[y];
-    });
-    int K = N;
-    for (int i = 1; i < N; ++i) {
-      K -= input[order_[i]] == input[order_[i - 1]];
-    }
-    uniqueTensor->Resize(K);
-    T* unique = uniqueTensor->template mutable_data<T>();
-    K = 0;
-    T prev = -1;
-    for (int i = 0; i < N; ++i) {
-      if (i == 0 || prev != input[order_[i]]) {
-        prev = unique[K++] = input[order_[i]];
-      }
-      if (remapping) {
-        remapping[order_[i]] = K - 1;
-      }
-    }
-  }
+  void DoRun();
 
  public:
   OUTPUT_TAGS(UNIQUE, REMAPPING);
diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py
index 564c748..10ffddc 100644
--- a/caffe2/python/hypothesis_test.py
+++ b/caffe2/python/hypothesis_test.py
@@ -906,7 +906,7 @@
                            dtype=np.int32,
                            elements=st.integers(min_value=0, max_value=10)),
            with_remapping=st.booleans(),
-           **hu.gcs_cpu_only)
+           **hu.gcs)
     def test_unique(self, input, with_remapping, gc, dc):
         op = core.CreateOperator(
             "Unique",