Refactor test for unique and unique_consecutive and fix some bugs (#31211)

Summary:
Refactors the tests for unique and unique_consecutive to share common helpers and to cover all dtypes as well as scalar and zero-sized inputs. It also fixes bugs the new tests exposed: the CPU unique_consecutive kernel returned a size-1 output for an empty input, the CUDA path built an ill-formed inverse_indices tensor when the input is empty, and the CUDA dispatches are extended to kHalf. Tests for unique_dim will be refactored in a separate PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31211

Differential Revision: D19034968

Pulled By: ngimel

fbshipit-source-id: 855d326b37638b5944f11fbbce03394cf000daf9
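
For reviewers, a minimal sketch of the edge cases the refactored tests assert, using only the public `torch.unique` / `torch.unique_consecutive` API; the expected values are taken directly from `_test_unique_scalar_empty` and `test_unique_consecutive` in the diff below:

```python
import torch

# Scalar input: unique is 1-D, inverse is 0-D, counts is [1].
x = torch.tensor(0)
u, inv, cnt = torch.unique(x, return_inverse=True, return_counts=True)
assert u.equal(torch.tensor([0]))
assert inv.equal(torch.tensor(0))
assert cnt.equal(torch.tensor([1]))

# Zero-sized input: unique and counts are empty, inverse keeps the input shape.
x = torch.zeros((0, 0, 3))
u, inv, cnt = torch.unique(x, return_inverse=True, return_counts=True)
assert u.numel() == 0
assert inv.shape == (0, 0, 3) and inv.dtype == torch.long
assert cnt.numel() == 0

# unique_consecutive only merges adjacent duplicates, so repeated values
# that are not adjacent stay separate in the output.
x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3])
u, cnt = torch.unique_consecutive(x, return_counts=True)
assert u.equal(torch.tensor([1, 2, 5, 2, 3]))
assert cnt.equal(torch.tensor([1, 3, 2, 2, 1]))
```

The scalar and zero-sized cases are what previously fell through the CPU and CUDA kernels fixed here.
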
diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index 30b1efe..2ac0403 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -81,41 +81,42 @@
   Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
   Tensor counts = at::empty({0}, self.options().dtype(kLong));
 
-  scalar_t *output_data = output.data_ptr<scalar_t>();
-  int64_t *inverse_data = nullptr;
-  int64_t *counts_data = nullptr;
-  if (numel > 0) {
-    *output_data = *input_data;
-  }
   if (return_inverse) {
     inverse_indices.resize_(input.sizes());
-    inverse_data = inverse_indices.data_ptr<int64_t>();
   }
-  if (return_counts) {
-    counts.resize_(input.sizes());
-    counts_data = counts.data_ptr<int64_t>();
-  }
-  scalar_t *p = output_data;
-  int64_t *q = counts_data;
-  int64_t last = 0;
-  for (int64_t i = 0; i < numel; i++) {
-    if (input_data[i] != *p) {
-      *(++p) = input_data[i];
-      if (return_counts) {
-        *(q++) = i - last;
-        last = i;
+
+  if (numel > 0) {
+    scalar_t *output_data = output.data_ptr<scalar_t>();
+    int64_t *inverse_data = inverse_indices.data_ptr<int64_t>();
+    int64_t *counts_data = nullptr;
+    *output_data = *input_data;
+
+    if (return_counts) {
+      counts.resize_({numel});
+      counts_data = counts.data_ptr<int64_t>();
+    }
+    scalar_t *p = output_data;
+    int64_t *q = counts_data;
+    int64_t last = 0;
+    for (int64_t i = 0; i < numel; i++) {
+      if (input_data[i] != *p) {
+        *(++p) = input_data[i];
+        if (return_counts) {
+          *(q++) = i - last;
+          last = i;
+        }
+      }
+      if (return_inverse) {
+        inverse_data[i] = p - output_data;
       }
     }
-    if (return_inverse) {
-      inverse_data[i] = p - output_data;
+    int64_t output_size = p - output_data + 1;
+    if (return_counts) {
+      *q = numel - last;
+      counts.resize_({output_size});
     }
+    output.resize_({output_size});
   }
-  int64_t output_size = p - output_data + 1;
-  if (return_counts && numel > 0) {
-    *q = numel - last;
-    counts.resize_({output_size});
-  }
-  output.resize_({output_size});
 
   return std::make_tuple(output, inverse_indices, counts);
 }
@@ -158,7 +159,7 @@
     auto sizes = self.sizes().vec();
     // check how many zero dimensions exist
     auto num_zero_dims = std::count(sizes.begin(), sizes.end(), 0);
-    
+
     // tensor is not well formed as it has 0 sized dimensions
     if (self.size(dim) == 0){
       TORCH_CHECK(
@@ -171,10 +172,10 @@
 
       return std::make_tuple(output, inverse_indices, counts);
     }
-    
+
     TORCH_CHECK(num_zero_dims == 0,
     "There are 0 sized dimensions, and they aren't selected, so unique cannot be applied");
-  
+
   // reshape tensor as [dim, -1]
   Tensor input_flat = self.transpose(dim, 0);
   auto orig_sizes = input_flat.sizes().vec();
diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu
index eeee91a..9c99b4c 100644
--- a/aten/src/ATen/native/cuda/Unique.cu
+++ b/aten/src/ATen/native/cuda/Unique.cu
@@ -36,7 +36,7 @@
 
   // inverse indices
   Tensor inverse_indices;
-  if (!return_inverse) {
+  if (!return_inverse || num_inp == 0) {
     inverse_indices = at::empty({0}, options);
   } else {
     TORCH_CHECK(sorted_indices.defined(),
@@ -139,11 +139,11 @@
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
   auto policy = thrust::cuda::par(allocator).on(stream);
-  
+
   auto sizes = self.sizes().vec();
   // check how many zero dimensions exist
   auto num_zero_dims = std::count(sizes.begin(), sizes.end(), 0);
-  
+
   // tensor is not well formed as it has 0 sized dimensions
   if (self.size(dim) == 0){
     TORCH_CHECK(
@@ -221,7 +221,7 @@
 
 std::tuple<Tensor, Tensor>
 _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
-  return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "unique", [&] {
+  return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
     Tensor output, inverse;
@@ -232,7 +232,7 @@
 
 std::tuple<Tensor, Tensor, Tensor>
 _unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
-  return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "unique", [&] {
+  return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
     return unique_cuda_template<scalar_t>(self, false, return_inverse, return_counts);
@@ -241,14 +241,14 @@
 
 std::tuple<Tensor, Tensor, Tensor>
 unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
-  return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "unique_dim", [&] {
+  return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique_dim", [&] {
     return unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, return_counts);
   });
 }
 
 std::tuple<Tensor, Tensor, Tensor>
 unique_dim_consecutive_cuda(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
-  return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "unique_dim", [&] {
+  return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique_dim", [&] {
     return unique_dim_cuda_template<scalar_t>(self, dim, true, return_inverse, return_counts);
   });
 }
@@ -256,7 +256,7 @@
 std::tuple<Tensor, Tensor, Tensor>
 unique_consecutive_cuda(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
   if (!dim.has_value()) {
-    return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "unique", [&] {
+    return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique", [&] {
       // The current CUDA implementation of unique always sort due to the
       // lack of hashtable implementation in thrust
       return unique_cuda_template<scalar_t>(self, true, return_inverse, return_counts);
diff --git a/test/test_torch.py b/test/test_torch.py
index f405969..dd9bdc5 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -12490,112 +12490,129 @@
                 result.is_contiguous(memory_format=torch.channels_last),
                 "result of the '{}' is not in channels_last format".format(inspect.getsource(fn).strip()))
 
-    def test_unique(self, device):
-        x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], device=device)
-        expected_unique = torch.tensor([1, 2, 3, 5, 8], device=device)
-        expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
-        expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)
+    def _test_unique_scalar_empty(self, dtype, device, f):
+        # test scalar
+        x = torch.tensor(0, dtype=dtype, device=device)
+        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
+        expected_unique = torch.tensor([0], dtype=dtype, device=device)
+        expected_inverse = torch.tensor(0, device=device)
+        expected_counts = torch.tensor([1], device=device)
+        self.assertEqual(unique, expected_unique)
+        self.assertEqual(inverse, expected_inverse)
+        self.assertEqual(counts, expected_counts)
 
-        x_unique = torch.unique(x)
-        self.assertEqual(
-            expected_unique.tolist(), sorted(x_unique.tolist()))
+        # test zero sized tensor
+        x = torch.zeros((0, 0, 3), dtype=dtype, device=device)
+        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
+        expected_unique = torch.tensor([], dtype=dtype, device=device)
+        expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device)
+        expected_counts = torch.tensor([], dtype=torch.long, device=device)
+        self.assertEqual(unique, expected_unique)
+        self.assertEqual(inverse, expected_inverse)
+        self.assertEqual(counts, expected_counts)
 
-        x_unique, x_inverse = x.unique(return_inverse=True)
-        self.assertEqual(
-            expected_unique.tolist(), sorted(x_unique.tolist()))
-        self.assertEqual(expected_inverse.numel(), x_inverse.numel())
+    def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
+        def ensure_tuple(x):
+            if torch.is_tensor(x):
+                return (x,)
+            return x
 
-        x_unique = x.unique(sorted=True)
-        self.assertEqual(expected_unique, x_unique)
+        for return_inverse in [True, False]:
+            for return_counts in [True, False]:
+                # test with expected
+                ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
+                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
+                self.assertEqual(expected_unique, ret[0])
+                if return_inverse:
+                    self.assertEqual(expected_inverse, ret[1])
+                if return_counts:
+                    count_index = 1 + int(return_inverse)
+                    self.assertEqual(expected_counts, ret[count_index])
 
-        x_unique, x_counts = torch.unique(x, sorted=True, return_counts=True)
-        self.assertEqual(expected_counts, x_counts)
+                # tests per-element unique on a higher rank tensor.
+                y = x.view(additional_shape)
+                y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True)
+                self.assertEqual(expected_unique, y_unique)
+                self.assertEqual(expected_inverse.view(additional_shape), y_inverse)
+                self.assertEqual(expected_counts, y_counts)
 
-        x_unique, x_inverse = torch.unique(
-            x, sorted=True, return_inverse=True)
-        self.assertEqual(expected_unique, x_unique)
-        self.assertEqual(expected_inverse, x_inverse)
+    @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16})
+    def test_unique(self, device, dtype):
+        if dtype is torch.half and self.device_type == 'cpu':
+            return  # CPU does not have half support
 
-        x_unique, x_inverse, x_counts = torch.unique(
-            x, sorted=True, return_inverse=True, return_counts=True)
-        self.assertEqual(expected_unique, x_unique)
-        self.assertEqual(expected_inverse, x_inverse)
-        self.assertEqual(expected_counts, x_counts)
+        def ensure_tuple(x):
+            if torch.is_tensor(x):
+                return (x,)
+            return x
 
-        # Tests per-element unique on a higher rank tensor.
-        y = x.view(2, 2, 2)
-        y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
-        self.assertEqual(expected_unique, y_unique)
-        self.assertEqual(expected_inverse.view(y.size()), y_inverse)
+        if dtype is torch.bool:
+            x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
+            expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
+            expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
+            expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
+        else:
+            x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
+            expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
+            expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
+            expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)
 
-        y_unique, y_inverse, y_counts = torch.unique(
-            y, sorted=True, return_inverse=True, return_counts=True)
-        self.assertEqual(expected_unique, y_unique)
-        self.assertEqual(expected_inverse.view(y.size()), y_inverse)
-        self.assertEqual(expected_counts, y_counts)
+        # test sorted unique
+        fs = [
+            lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
+            lambda x, **kwargs: x.unique(sorted=True, **kwargs),
+        ]
+        for f in fs:
+            self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
+            self._test_unique_scalar_empty(dtype, device, f)
 
-        # Tests unique on other types.
-        int_unique, int_inverse, int_counts = torch.unique(
-            torch.tensor([2, 1, 2], dtype=torch.int, device=device),
-            sorted=True,
-            return_inverse=True,
-            return_counts=True
-        )
-        self.assertEqual(torch.tensor([1, 2], dtype=torch.int, device=device), int_unique)
-        self.assertEqual(torch.tensor([1, 0, 1], dtype=torch.long, device=device), int_inverse)
-        self.assertEqual(torch.tensor([1, 2], dtype=torch.long, device=device), int_counts)
+        # test unsorted unique
+        fs = [
+            lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
+            lambda x, **kwargs: x.unique(sorted=False, **kwargs)
+        ]
+        for f in fs:
+            self._test_unique_scalar_empty(dtype, device, f)
+            for return_inverse in [True, False]:
+                for return_counts in [True, False]:
+                    ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
+                    self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
+                    x_list = x.tolist()
+                    x_unique_list = ret[0].tolist()
+                    self.assertEqual(expected_unique.tolist(), sorted(x_unique_list))
+                    if return_inverse:
+                        x_inverse_list = ret[1].tolist()
+                        for i, j in enumerate(x_inverse_list):
+                            self.assertEqual(x_list[i], x_unique_list[j])
+                    if return_counts:
+                        count_index = 1 + int(return_inverse)
+                        x_counts_list = ret[count_index].tolist()
+                        for i, j in zip(x_unique_list, x_counts_list):
+                            count = 0
+                            for k in x_list:
+                                if k == i:
+                                    count += 1
+                            self.assertEqual(j, count)
 
-        double_unique, double_inverse, double_counts = torch.unique(
-            torch.tensor([2., 1.5, 2.1, 2.], dtype=torch.double, device=device),
-            sorted=True,
-            return_inverse=True,
-            return_counts=True
-        )
-        self.assertEqual(torch.tensor([1.5, 2., 2.1], dtype=torch.double, device=device), double_unique)
-        self.assertEqual(torch.tensor([1, 0, 2, 1], dtype=torch.long, device=device), double_inverse)
-        self.assertEqual(torch.tensor([1, 2, 1], dtype=torch.long, device=device), double_counts)
+    @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16})
+    def test_unique_consecutive(self, device, dtype):
+        if dtype is torch.half and self.device_type == 'cpu':
+            return  # CPU does not have half support
 
-        byte_unique, byte_inverse, byte_counts = torch.unique(
-            torch.tensor([133, 7, 7, 7, 42, 128], dtype=torch.uint8, device=device),
-            sorted=True,
-            return_inverse=True,
-            return_counts=True
-        )
-        self.assertEqual(torch.tensor([7, 42, 128, 133], dtype=torch.uint8, device=device), byte_unique)
-        self.assertEqual(torch.tensor([3, 0, 0, 0, 1, 2], dtype=torch.long, device=device), byte_inverse)
-        self.assertEqual(torch.tensor([3, 1, 1, 1], dtype=torch.long, device=device), byte_counts)
+        if dtype is torch.bool:
+            x = torch.tensor([True, False, False, False, True, True, False, False, False], dtype=torch.bool, device=device)
+            expected_unique = torch.tensor([True, False, True, False], dtype=torch.bool, device=device)
+            expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device)
+            expected_counts = torch.tensor([1, 3, 2, 3], dtype=torch.long, device=device)
+        else:
+            x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device)
+            expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device)
+            expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device)
+            expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device)
 
-        bool_unique, bool_inverse, bool_counts = torch.unique(
-            torch.tensor([True, False, True, False], dtype=torch.bool, device=device),
-            sorted=True,
-            return_inverse=True,
-            return_counts=True
-        )
-        self.assertEqual(torch.tensor([False, True], dtype=torch.bool, device=device), bool_unique)
-        self.assertEqual(torch.tensor([1, 0, 1, 0], dtype=torch.long, device=device), bool_inverse)
-        self.assertEqual(torch.tensor([2, 2], dtype=torch.long, device=device), bool_counts)
-
-        # test consecutive version
-        z = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], device=device)
-        expected_z_unique = torch.tensor([1, 2, 5, 2, 3], device=device)
-        expected_z_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device)
-        expected_z_counts = torch.tensor([1, 3, 2, 2, 1], device=device)
-
-        z_unique = torch.unique_consecutive(z)
-        self.assertEqual(z_unique, expected_z_unique)
-
-        z_unique, z_inverse = torch.unique_consecutive(z, return_inverse=True)
-        self.assertEqual(z_unique, expected_z_unique)
-        self.assertEqual(z_inverse, expected_z_inverse)
-
-        z_unique, z_counts = torch.unique_consecutive(z, return_counts=True)
-        self.assertEqual(z_unique, expected_z_unique)
-        self.assertEqual(z_counts, expected_z_counts)
-
-        z_unique, z_inverse, z_counts = torch.unique_consecutive(z, return_inverse=True, return_counts=True)
-        self.assertEqual(z_unique, expected_z_unique)
-        self.assertEqual(z_inverse, expected_z_inverse)
-        self.assertEqual(z_counts, expected_z_counts)
+        for f in [torch.unique_consecutive, lambda x, **kwargs: x.unique_consecutive(**kwargs)]:
+            self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (3, 3))
+            self._test_unique_scalar_empty(dtype, device, f)
 
     @dtypesIfCUDA(torch.half, torch.float, torch.double)
     @dtypes(torch.float, torch.double)