Adds `dim` argument to `torch.unique` (#10423)

Summary:
Initial version of `unique` supporting a `dim` argument.

As discussed in [this issue](https://github.com/pytorch/pytorch/issues/9997), I added the `dim` argument to `torch.unique` with the same behavior as [numpy](https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.unique.html).

Since the implementation is based on `std/thrust::unique`, the `tensor` always needs to be sorted. The `sorted` argument of `torch.unique` therefore has no effect when `dim` is given, as is already the case in the CUDA version of the plain `torch.unique`.

To check the performance and equal behavior between `torch.unique` and `np.unique`, I've used [this gist](https://gist.github.com/ptrblck/ac0dc862f4e1766f0e1036c252cdb105).

Currently we achieve the following timings for an input of `x = torch.randint(2, (1000, 1000))`:
(The values are calculated by averaging the times for both dimensions.)

| Device | PyTorch (return_inverse=False) | Numpy (return_inverse=False) | PyTorch (return_inverse=True) | Numpy (return_inverse=True) |
| --- | --- | --- | --- | --- |
| CPU | ~0.007331s | ~0.022452s | ~0.011139s | ~0.044800s |
| GPU | ~0.006154s | - | ~0.105373s | - |

Many thanks to colesbury for the awesome mentoring and the valuable advice on the general implementation and performance issues!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10423

Differential Revision: D9517289

Pulled By: soumith

fbshipit-source-id: a4754f805223589c2847c98b8e4e39d8c3ddb7b5
diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index d9bd94e..d5ff300 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -47,6 +47,82 @@
   }
   return std::make_tuple(output, inverse_indices);
 }
+
+// Collapses each run of adjacent slices that compare at::equal into its first
+// element, in the manner of std::unique, and returns the new logical end.
+// Along the way, inverse_indices_vec[indices[k]] receives the output position
+// of the slice that the k-th element of the range was merged into.
+template<class ForwardIt>
+ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
+  std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
+  if (first == last) {
+    return last;
+  }
+  // the first slice always survives at output position 0
+  inverse_indices_vec[indices[0]] = 0;
+  ForwardIt tail = first;    // last slice kept so far
+  ForwardIt cursor = first;  // slice currently being examined
+  for (++cursor; cursor != last; ++cursor) {
+    if (!at::equal(*tail, *cursor)) {
+      if (++tail != cursor) { *tail = std::move(*cursor); }
+    }
+    const int64_t kept_pos = std::distance(first, tail);
+    const int64_t seen_pos = std::distance(first, cursor);
+    inverse_indices_vec[indices[seen_pos]] = kept_pos;
+  }
+  return ++tail;
+}
+
+// Returns (unique slices of `self` along `dim` in ascending lexicographic order,
+// int64 inverse indices mapping every input slice to its row in that output).
+// `return_inverse` is not read here: the inverse indices are always computed.
+template <typename scalar_t>
+std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
+    const Tensor& self,
+    const int64_t dim,
+    const bool return_inverse) {
+  // move `dim` to the front and flatten the rest: [self.size(dim), -1]
+  Tensor input_flat = self.transpose(dim, 0);
+  auto orig_sizes = input_flat.sizes().vec();
+  input_flat = input_flat.contiguous().view({input_flat.size(0), -1});
+  const int64_t num_slices = input_flat.size(0);
+  const int64_t numel = input_flat.size(1);
+  std::vector<int64_t> indices(num_slices);
+  std::iota(indices.begin(), indices.end(), 0);
+  const scalar_t* input_flat_ptr = input_flat.data<scalar_t>();
+
+  // argsort the rows lexicographically; the data itself is not moved yet
+  std::sort(indices.begin(), indices.end(),
+    [&](int64_t a, int64_t b) -> bool {
+      for (int64_t i = 0; i < numel; ++i) {
+        scalar_t lhs = input_flat_ptr[i + a * numel];
+        scalar_t rhs = input_flat_ptr[i + b * numel];
+        if (lhs < rhs) return true;
+        if (lhs > rhs) return false;
+      }
+      return false;
+    });
+
+  // gather rows in sorted order (64-bit counter: no int vs size_t comparison)
+  Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.type());
+  for (int64_t i = 0; i < num_slices; ++i) {
+    input_sorted[i] = input_flat[indices[i]];
+  }
+
+  Tensor inverse_indices = at::empty(indices.size(), self.type().toScalarType(kLong));
+  std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
+  auto last = _unique_dim_cpu_impl(
+    input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
+  input_unbind.erase(last, input_unbind.end());
+
+  // restore the original layout, with the deduplicated size along `dim`
+  auto output = at::stack(input_unbind, 0);
+  auto new_sizes = std::vector<int64_t>(orig_sizes);
+  new_sizes[0] = -1;
+  output = output.view(new_sizes);
+  output = output.transpose(0, dim);
+  return std::make_tuple(output, inverse_indices);
+}
 } // namespace
 
 std::tuple<Tensor, Tensor>
@@ -56,5 +132,13 @@
   });
 }
 
+// `sorted` is ignored: slices are not hashable, so the `dim` path always sorts.
+std::tuple<Tensor, Tensor> _unique_dim_cpu(
+    const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+  return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
+  });
+}
+
 }  // namespace native
 }  // namespace at
diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu
index f2e13b4..c29337f 100644
--- a/aten/src/ATen/native/cuda/Unique.cu
+++ b/aten/src/ATen/native/cuda/Unique.cu
@@ -69,6 +69,92 @@
     return std::tuple<Tensor, Tensor>(output, inverse_indices);
 
   }
+
+// Returns (unique slices of `self` along `dim` in ascending lexicographic order,
+// int64 inverse indices; the latter is left empty unless `return_inverse` is set).
+template <typename scalar_t>
+  std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
+    const Tensor& self,
+    const int64_t dim,
+    const bool return_inverse) {
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+    auto policy = thrust::cuda::par(allocator).on(stream);
+
+    // move `dim` to the front and flatten the rest: [self.size(dim), -1]
+    Tensor input_flat = self.transpose(dim, 0);
+    auto orig_sizes = input_flat.sizes().vec();
+    input_flat = input_flat.contiguous().view({input_flat.size(0), -1});
+
+    scalar_t* input_flat_ptr = input_flat.data<scalar_t>();
+
+    Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
+    int64_t* indices_ptr = indices.data<int64_t>();
+    int64_t numel = input_flat.size(1);
+
+    // sort indices using data
+    thrust::sort(policy, indices_ptr, indices_ptr + indices.numel(),
+      [=] __device__ (int64_t a, int64_t b) -> bool {
+        for (int64_t i = 0; i < numel; ++i) {
+          scalar_t lhs = input_flat_ptr[i + a * numel];
+          scalar_t rhs = input_flat_ptr[i + b * numel];
+          if (lhs < rhs) {
+            return true;
+          } else if (lhs > rhs) {
+            return false;
+          }
+        }
+        return false;
+      });
+
+    Tensor input_sorted = input_flat.index_select(0, indices);
+
+    // get unique tensors
+    scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
+    Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
+    int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
+    auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
+      [=] __device__ (int64_t a, int64_t b) -> bool {
+        for (int64_t i = 0; i < numel; ++i) {
+          scalar_t lhs = input_sorted_ptr[i + a * numel];
+          scalar_t rhs = input_sorted_ptr[i + b * numel];
+          if (lhs != rhs) {
+            return false;
+          }
+        }
+        return true;
+      });
+    input_sorted_indices.resize_(last - input_sorted_indices_ptr);
+    Tensor output = input_sorted.index_select(0, input_sorted_indices);
+
+    // reshape back
+    auto new_sizes = std::vector<int64_t>(orig_sizes);
+    new_sizes[0] = -1;
+    output = output.view(new_sizes);
+    output = output.transpose(0, dim);
+
+    // calculate inverse indices
+    Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
+    if (return_inverse) {
+      int64_t size = self.size(dim);
+      inverse_indices.resize_(size);
+      const int64_t num_sorted = input_sorted.size(0);
+      Tensor mask = at::empty(num_sorted, self.type().toScalarType(kLong));
+      // mask[k] == 1 marks the first row of every run of equal sorted rows
+      mask[0] = 1;
+      for (int64_t i = 0; i + 1 < num_sorted; ++i) {
+        mask[i + 1] = at::equal(input_sorted[i], input_sorted[i + 1]) ? 0 : 1;
+      }
+
+      // running count of run starts, minus one = output row of each sorted row
+      Tensor imask = at::cumsum(mask, 0) - 1;
+      // scatter to the original slice order in one call, not a per-element host loop
+      inverse_indices.index_copy_(0, indices, imask);
+    }
+
+    THCudaCheck(cudaGetLastError());
+    return std::tuple<Tensor, Tensor>(output, inverse_indices);
+  }
 } // namespace
 
 #endif
@@ -86,5 +172,16 @@
 #endif
 }
 
+// `sorted` is not read; on HIP builds this always raises instead of dispatching.
+std::tuple<Tensor, Tensor> _unique_dim_cuda(
+    const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+#ifdef __HIP_PLATFORM_HCC__
+  AT_ERROR("unique_dim_cuda: HIP not supported");
+#else
+  return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim",
+    [&] { return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse); });
+#endif
+}
+
 }  // namespace native
 }  // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 466fe6c..cb194cd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1748,6 +1748,11 @@
     CPU: _unique_cpu
     CUDA: _unique_cuda
 
+- func: _unique_dim(Tensor self, int64_t dim, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
+  dispatch:  # note: both kernels accept `sorted` but do not read it
+    CPU: _unique_dim_cpu
+    CUDA: _unique_dim_cuda
+
 - func: _unsafe_view(Tensor self, IntList size) -> Tensor
   variants: function
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 167a400..863f97f 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8485,6 +8485,67 @@
         self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
         self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
 
+    def test_unique_dim(self):
+        def run_test(dtype=torch.float):
+            x = torch.tensor([[[1., 1.],
+                               [0., 1.],
+                               [2., 1.],
+                               [0., 1.]],
+                              [[1., 1.],
+                               [0., 1.],
+                               [2., 1.],
+                               [0., 1.]]], dtype=dtype)
+            expected_unique_dim0 = torch.tensor([[[1., 1.],
+                                                  [0., 1.],
+                                                  [2., 1.],
+                                                  [0., 1.]]], dtype=dtype)
+            expected_inverse_dim0 = torch.tensor([0, 0])
+            expected_unique_dim1 = torch.tensor([[[0., 1.],
+                                                  [1., 1.],
+                                                  [2., 1.]],
+                                                 [[0., 1.],
+                                                  [1., 1.],
+                                                  [2., 1.]]], dtype=dtype)
+            expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
+            expected_unique_dim2 = torch.tensor([[[1., 1.],
+                                                  [0., 1.],
+                                                  [2., 1.],
+                                                  [0., 1.]],
+                                                 [[1., 1.],
+                                                  [0., 1.],
+                                                  [2., 1.],
+                                                  [0., 1.]]], dtype=dtype)
+            expected_inverse_dim2 = torch.tensor([0, 1])
+
+            # dim0
+            x_unique = torch.unique(x, dim=0)
+            self.assertEqual(expected_unique_dim0, x_unique)
+
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
+            self.assertEqual(expected_unique_dim0, x_unique)
+            self.assertEqual(expected_inverse_dim0, x_inverse)
+
+            # dim1
+            x_unique = torch.unique(x, dim=1)
+            self.assertEqual(expected_unique_dim1, x_unique)
+
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
+            self.assertEqual(expected_unique_dim1, x_unique)
+            self.assertEqual(expected_inverse_dim1, x_inverse)
+
+            # dim2
+            x_unique = torch.unique(x, dim=2)
+            self.assertEqual(expected_unique_dim2, x_unique)
+
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
+            self.assertEqual(expected_unique_dim2, x_unique)
+            self.assertEqual(expected_inverse_dim2, x_inverse)
+
+        run_test(torch.float)
+        run_test(torch.double)
+        run_test(torch.long)
+        run_test(torch.uint8)
+
     @staticmethod
     def _test_bincount(self, device):
         # negative input throws
diff --git a/torch/functional.py b/torch/functional.py
index 055141b..8c78b6e 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -389,7 +389,7 @@
     return tensor != tensor
 
 
-def unique(input, sorted=False, return_inverse=False):
+def unique(input, sorted=False, return_inverse=False, dim=None):
     r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
 
     Arguments:
@@ -431,11 +431,19 @@
                 [ 1,  2]])
 
     """
-    output, inverse_indices = torch._unique(
-        input,
-        sorted=sorted,
-        return_inverse=return_inverse,
-    )
+    if dim is not None:
+        output, inverse_indices = torch._unique_dim(
+            input,
+            dim,
+            sorted=sorted,
+            return_inverse=return_inverse
+        )
+    else:
+        output, inverse_indices = torch._unique(
+            input,
+            sorted=sorted,
+            return_inverse=return_inverse,
+        )
     if return_inverse:
         return output, inverse_indices
     else:
diff --git a/torch/tensor.py b/torch/tensor.py
index ed2f7f0..904d3a5 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -319,13 +319,22 @@
         """
         return self.clone().masked_fill_(mask, value)
 
-    def unique(self, sorted=False, return_inverse=False):
+    def unique(self, sorted=False, return_inverse=False, dim=None):
         r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
 
         See :func:`torch.unique`
         """
-        output, inverse_indices = self._unique(
-            sorted=sorted, return_inverse=return_inverse)
+        if dim is not None:
+            output, inverse_indices = self._unique_dim(
+                sorted=sorted,
+                return_inverse=return_inverse,
+                dim=dim
+            )
+        else:
+            output, inverse_indices = self._unique(
+                sorted=sorted,
+                return_inverse=return_inverse
+            )
         if return_inverse:
             return output, inverse_indices
         else: