Revert "Revert 107846 and 109695 (#111099)" (#113420)

The algorithm is taken from the NumPy implementation at https://github.com/numpy/numpy/blob/main/numpy/lib/arraysetops.py#L323: it first sorts the input sequence and then uses a `mask` to mark the first element of each consecutive section of equal values.
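
In NumPy terms, the core of the approach looks roughly like this (a minimal sketch assuming a 1-D input; the actual implementation is the C++ code in the diff below):

```python
import numpy as np

def unique_via_sort(ar):
    # Sort first, then mark the first element of each run of equal values.
    ar = np.sort(np.ravel(ar))
    mask = np.empty(ar.shape, dtype=bool)
    mask[:1] = True                 # the first element always starts a section
    mask[1:] = ar[1:] != ar[:-1]    # True wherever a new consecutive section starts
    return ar[mask]
```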

Parallel radix sort is used for 1-dimensional int tensors; parallel sort for 1-dimensional float tensors is not available yet and will be enabled in a next step.

The following data was collected with the script from the issue on an Intel(R) Xeon(R) Gold 6248 CPU @ 2.5GHz with a single socket (20 cores):

#### before (dtype int64)
```
Numpy just sort: 0.4271528720855713 s
Numpy sort + indexes: 6.383563041687012 s
Torch just sort: 0.46924352645874023 s
Torch sort + indexes: 1.8140404224395752 s
```

#### after (dtype int64)
```
Torch just sort: 0.2540090084075928 s
Torch sort + indexes: 0.2766146659851074 s
```

#### before (float32)
```
Numpy just sort: 0.41129398345947266 s
Numpy sort + indexes: 6.422696590423584 s
Torch just sort: 9.109549283981323 s
Torch sort + indexes: 37.59021711349487 s
```

#### after (float32)
```
Torch just sort: 3.5369982719421387 s
Torch sort + indexes: 3.582240581512451 s
```

If parallel sort is enabled for 1-dimensional float tensors, the performance becomes:
```
Torch just sort: 0.3212606906890869 s
Torch sort + indexes: 0.36211371421813965 s
```
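
For reference, the "just sort" vs. "sort + indexes" labels presumably correspond to calling `unique` without and with `return_inverse`/`return_counts`; the exact script is the one from the issue, but a hypothetical timing sketch (tensor size and value range are placeholders here) would look like:

```python
import time
import torch

# placeholder input; the issue's script defines the actual size and dtype
x = torch.randint(0, 1000, (50_000_000,), dtype=torch.int64)

t0 = time.time()
torch.unique(x, sorted=True)                                            # "just sort"
t1 = time.time()
torch.unique(x, sorted=True, return_inverse=True, return_counts=True)   # "sort + indexes"
t2 = time.time()

print(f"Torch just sort: {t1 - t0} s")
print(f"Torch sort + indexes: {t2 - t1} s")
```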

Since the `inverse_indices` and `counts` calculations are fused into the same parallel loop (the algorithm is identical to NumPy's but better optimized), they add only a small amount of time on top of the sort.
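
The NumPy-side equivalent of that bookkeeping (inverse indices recovered through the sort permutation, counts from differences between section start indices) is sketched below for a 1-D input; the C++ version in the diff fuses these into one parallel loop:

```python
import numpy as np

def unique_with_inverse_and_counts(ar):
    # perm maps positions in the sorted sequence back to original positions
    perm = np.argsort(ar, kind="stable")
    aux = ar[perm]

    # mask marks the first element of each consecutive section of equal values
    mask = np.empty(aux.shape, dtype=bool)
    mask[:1] = True
    mask[1:] = aux[1:] != aux[:-1]
    unique = aux[mask]

    # inverse: each element maps to the index of its section in the unique output
    inverse = np.empty(ar.shape, dtype=np.intp)
    inverse[perm] = np.cumsum(mask) - 1

    # counts: differences between the start indices of consecutive sections
    starts = np.concatenate((np.nonzero(mask)[0], [mask.size]))
    counts = np.diff(starts)
    return unique, inverse, counts
```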

A reduction-based implementation is used for `unique` when the dtype is bool on CPU.
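
For bool the output can only be `[False]`, `[True]`, or `[False, True]`, so a single reduction that counts the `True` values is enough. A minimal sketch of the idea in NumPy (the real code is `unique_cpu_bool_template` in the diff below):

```python
import numpy as np

def unique_bool(ar, return_inverse=False, return_counts=False):
    # count True values with one pass; everything else follows from the counts
    num_true = int(np.count_nonzero(ar))
    num_false = ar.size - num_true

    output = np.array([v for v, n in ((False, num_false), (True, num_true)) if n > 0],
                      dtype=bool)

    inverse = None
    if return_inverse:
        true_idx = 1 if num_false > 0 else 0
        inverse = np.where(ar, true_idx, 0)

    counts = None
    if return_counts:
        counts = np.array([n for n in (num_false, num_true) if n > 0], dtype=np.intp)

    return output, inverse, counts
```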

This reverts commit 6dca81c054c1f7e378e956900265b085ca521e47, as the `torch.sort` errors have been fixed in FBGEMM by https://github.com/pytorch/FBGEMM/commit/70c6e83c29f67278751abd0e28433c50743ccbe9.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113420
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index 22222f8..e000a98 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -3,6 +3,8 @@
 
 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <ATen/native/TensorIterator.h>
 #include <c10/util/irange.h>
 #include <c10/util/Load.h>
 
@@ -23,122 +25,242 @@
 #include <ATen/ops/zeros.h>
 #endif
 
-#include <tuple>
-#include <unordered_map>
-#include <unordered_set>
-
-namespace std {
-template <>
-struct hash<at::BFloat16> {
-  size_t operator()(const at::BFloat16& v) const noexcept {
-    return std::hash<uint16_t>()(v.x);
-  }
-};
-
-template <>
-struct hash<at::Half> {
-  size_t operator()(const at::Half& v) const noexcept {
-    return std::hash<uint16_t>()(v.x);
-  }
-};
-} // namespace std
-
 namespace at {
 namespace native{
 
 namespace {
 
-// Extract the unique elements from [begin, end) into a new Tensor
-template <typename scalar_t>
-Tensor unique_elements(const scalar_t* begin, const scalar_t* end,
-                       bool sorted, const TensorOptions &options) {
-  // Create unordered set of elements
-  auto set = std::unordered_set<scalar_t>(begin, end);
-
-  // Write the output tensor
-  Tensor output = at::empty({static_cast<int64_t>(set.size())}, options);
-  scalar_t *output_data = output.mutable_data_ptr<scalar_t>();
-  std::copy(set.begin(), set.end(), output_data);
-  if (sorted) {
-    std::sort(output_data, output_data + set.size());
-  }
-  return output;
-}
-
-// Specialization for boolean inputs, since we can't construct a set
-// directly from an array of bool as it won't handle invalid byte values.
-// See NOTE [Loading boolean values]
-Tensor unique_elements(const bool* begin, const bool* end,
-                       bool /*sorted*/, const TensorOptions &options) {
-  // Instead of a set, track whether a value has been seen
-  std::array<bool, 2> seen;
-  seen.fill(false);
-
-  for (; begin != end; ++begin) {
-    seen[c10::load(begin)] = true;
-    if (seen[false] && seen[true]) {
-      break;
-    }
-  }
-
-  // Write the output tensor
-  int64_t num_elem = seen[false] + seen[true];
-  Tensor output = at::empty({num_elem}, options);
-  bool *output_data = output.mutable_data_ptr<bool>();
-
-  if (seen[false]) {
-    *output_data++ = false;
-  }
-  if (seen[true]) {
-    *output_data++ = true;
-  }
-  return output;
-}
-
-template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> unique_cpu_template(
+// This bool specialization of unique is adapted from UniqueCub.cu,
+// which uses a reduction to find the number of
+// true values.
+std::tuple<Tensor, Tensor, Tensor> unique_cpu_bool_template(
     const Tensor& self,
-    const bool sorted,
     const bool return_inverse,
     const bool return_counts) {
   const Tensor& input = self.contiguous();
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
+  bool* input_data = input.data_ptr<bool>();
+
   int64_t numel = input.numel();
+  Tensor output = at::empty({0}, self.options());
   Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
   Tensor counts = at::empty({0}, self.options().dtype(kLong));
-  Tensor output = unique_elements(input_data, input_data + numel,
-                                  sorted, input.options());
-  const scalar_t *output_data = output.data_ptr<scalar_t>();
 
-  if (return_inverse || return_counts) {
+  if (numel == 0) {
+    if (return_inverse) {
+      inverse_indices.resize_(input.sizes());
+    }
+    return std::make_tuple(output, inverse_indices, counts);
+  }
+
+  int num_threads = at::get_num_threads();
+  std::vector<int64_t> num_true_thread(num_threads, 0);
+
+  const int64_t grain_size = at::internal::GRAIN_SIZE;
+  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
+    int tid = at::get_thread_num();
+    for (const auto i : c10::irange(begin, end)) {
+      const bool value = c10::load(&input_data[i]);
+      if (value) {
+        num_true_thread[tid]++;
+      }
+    }
+  });
+
+  int64_t num_true = std::accumulate(num_true_thread.begin(), num_true_thread.end(), static_cast<int64_t>(0));
+  int64_t num_false = numel - num_true;
+  int num_out = ((num_true > 0) + (num_false > 0));
+
+  constexpr int false_idx = 0;
+  const int true_idx = num_false > 0;
+
+  output.resize_({num_out});
+  if (return_counts) {
+    counts.resize_({num_out});
+  }
+  bool* output_data = output.data_ptr<bool>();
+  int64_t* counts_data = return_counts ? counts.data_ptr<int64_t>() : nullptr;
+
+  // write output and counts
+  if (num_false > 0) {
+    output_data[false_idx] = false;
+    if (return_counts) {
+      counts_data[false_idx] = num_false;
+    }
+  }
+  if (num_true > 0) {
+    output_data[true_idx] = true;
+    if (return_counts) {
+      counts_data[true_idx] = num_true;
+    }
+  }
+
+  if (return_inverse) {
     inverse_indices.resize_(input.sizes());
     int64_t* inverse_indices_data = inverse_indices.data_ptr<int64_t>();
-    std::unordered_map<scalar_t, int64_t> inverse_map;
-    inverse_map.reserve(output.numel());
-    for (const auto i : c10::irange(output.numel())) {
-      inverse_map[output_data[i]] = i;
-    }
-    for (const auto i : c10::irange(numel)) {
-      const auto val = c10::load(&input_data[i]);
-      inverse_indices_data[i] = inverse_map[val];
-    }
-    if (return_counts) {
-      std::unordered_map<scalar_t, int64_t> counts_map;
-      counts_map.reserve(output.numel());
-      for (const auto i : c10::irange(output.numel())) {
-        counts_map[output_data[i]] = 0;
+    at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
+      for (const auto i : c10::irange(begin, end)) {
+        const bool value = c10::load(&input_data[i]);
+        inverse_indices_data[i] = value ? true_idx : false_idx;
       }
-      for (const auto i : c10::irange(numel)) {
-        const auto val = c10::load(&input_data[i]);
-        counts_map[val] += 1;
-      }
-      counts.resize_(output.sizes());
-      counts.fill_(0);
-      int64_t *counts_data = counts.data_ptr<int64_t>();
-      for (const auto i : c10::irange(output.numel())) {
-        counts_data[i] = counts_map[output_data[i]];
+    });
+  }
+  return std::make_tuple(output, inverse_indices, counts);
+}
+
+// Check whether the element at index i is `unique`:
+// in a sorted sequence, the 1st element is always unique.
+//
+// NaN is propagated to the rear in a sorted sequence,
+// consider a sorted sequence of
+//   {1.0, 1.0, 2.0, 2.0, NaN, NaN, NaN}
+//
+// a. `equal_nan` == true will give:
+//   {T,   F,   T,   F,   T,   F,   F  }
+//
+// b. `equal_nan` == false will give:
+//   {T,   F,   T,   F,   T,   T,   T  }
+//
+template <typename scalar_t, bool equal_nan>
+struct IsUnique {};
+
+template <typename scalar_t>
+struct IsUnique<scalar_t, false> {
+  inline bool operator() (scalar_t* data_ptr, int64_t i) {
+    if (i == 0) { return true; }
+    return c10::load(&data_ptr[i]) != c10::load(&data_ptr[i - 1]);
+  }
+};
+
+template <typename scalar_t>
+struct IsUnique<scalar_t, true> {
+  inline bool operator() (scalar_t* data_ptr, int64_t i) {
+    if (i == 0) { return true; }
+    return (c10::load(&data_ptr[i]) != c10::load(&data_ptr[i - 1]))
+        && !(_isnan(data_ptr[i]) && _isnan(data_ptr[i - 1]));
+  }
+};
+
+// NB: Unique implementation using sort
+//
+// The algorithm is taken from NumPy's numpy/lib/arraysetops.py,
+// which first sorts the input sequence and then reduces it to
+// consecutive unique values.
+//
+// One improvement over the NumPy version: `inverse_indices` and
+// `counts` are computed in parallel in a fused loop,
+// which makes this part almost free.
+//
+// This kernel also implements an `equal_nan` flag with the same
+// semantics as NumPy's unique. Currently it is always disabled.
+//
+// TODO: add `bool` specialization, use similar approach as UniqueCub
+//
+template <typename scalar_t, typename CompareOp>
+std::tuple<Tensor, Tensor, Tensor> unique_cpu_sorted_template(
+    const Tensor& self,
+    const bool return_inverse,
+    const bool return_counts,
+    CompareOp is_unique) {
+  const Tensor& input = self.contiguous();
+
+  int64_t numel = input.numel();
+  Tensor output = at::empty({0}, self.options());
+  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
+  Tensor counts = at::empty({0}, self.options().dtype(kLong));
+
+  if (numel == 0) {
+    if (return_inverse) {
+      inverse_indices.resize_(input.sizes());
+    }
+    return std::make_tuple(output, inverse_indices, counts);
+  }
+
+  // index of the first unique element in each consecutive section;
+  // this is used to compute counts in parallel
+  Tensor unique_index = at::empty({0}, self.options().dtype(kLong));
+
+  // the original behavior of unique on a scalar tensor is to return
+  // an output of size [1]; `flatten` here takes care of that
+  auto input_flattened = input.flatten();
+
+  Tensor input_sorted, indices;
+  std::tie(input_sorted, indices) = input_flattened.sort();
+
+  scalar_t* input_sorted_data = input_sorted.data_ptr<scalar_t>();
+  int64_t* indices_data = indices.data_ptr<int64_t>();
+
+  int num_threads = at::get_num_threads();
+  std::vector<int64_t> unique_count_thread(num_threads, 0);
+  std::vector<int64_t> offset_thread(num_threads, 0);
+
+  const int64_t grain_size = at::internal::GRAIN_SIZE;
+
+  // calculate unique count from each thread
+  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
+    int tid = at::get_thread_num();
+    for (const auto i : c10::irange(begin, end)) {
+      if (is_unique(input_sorted_data, i)) {
+        unique_count_thread[tid]++;
       }
     }
+  });
+
+  // calculate each thread's offset into the output;
+  // `unique_count` ends up holding the total number of uniques
+  int64_t unique_count = 0;
+  for (const auto t : c10::irange(num_threads)) {
+    offset_thread[t] = unique_count;
+    unique_count += unique_count_thread[t];
+  }
+
+  output.resize_({unique_count});
+  scalar_t* output_data = output.data_ptr<scalar_t>();
+
+  int64_t* inverse_indices_data = nullptr;
+  if (return_inverse) {
+    inverse_indices.resize_(input.sizes());
+    inverse_indices_data = inverse_indices.data_ptr<int64_t>();
+  }
+
+  int64_t* counts_data = nullptr;
+  int64_t* unique_index_data = nullptr;
+  if (return_counts) {
+    counts.resize_({unique_count});
+    counts_data = counts.data_ptr<int64_t>();
+
+    unique_index.resize_({unique_count + 1});
+    unique_index_data = unique_index.data_ptr<int64_t>();
+    unique_index_data[unique_count] = numel;
+  }
+
+  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
+    int tid = at::get_thread_num();
+    int64_t offset = offset_thread[tid];
+
+    for (const auto i : c10::irange(begin, end)) {
+      if (is_unique(input_sorted_data, i)) {
+        output_data[offset] = c10::load(&input_sorted_data[i]);
+        if (return_counts) {
+          unique_index_data[offset] = i;
+        }
+        offset++;
+      }
+
+      if (return_inverse) {
+        int64_t inverse_index = offset - 1;
+        int64_t perm = indices_data[i];
+        inverse_indices_data[perm] = inverse_index;
+      }
+    }
+  });
+
+  if (return_counts) {
+    // do diff to get count
+    at::parallel_for(0, unique_count, grain_size, [&](int64_t begin, int64_t end) {
+      for (const auto i : c10::irange(begin, end)) {
+        counts_data[i] = unique_index_data[i + 1] - unique_index_data[i];
+      }
+    });
   }
   return std::make_tuple(output, inverse_indices, counts);
 }
@@ -318,20 +440,34 @@
 
 } // namespace
 
-
 std::tuple<Tensor, Tensor>
 _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
-  return AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kBool, kHalf, self.scalar_type(), "unique", [&] {
+  if (self.scalar_type() == kBool) {
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = unique_cpu_template<scalar_t>(self, sorted, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = unique_cpu_bool_template(
+        self, return_inverse, /* return_counts */false);
+    return std::make_tuple(output, inverse);
+  }
+  return AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "unique", [&] {
+    Tensor output, inverse;
+    // The current CPU implementation of unique always sorts,
+    // since this is faster than using a hash table
+    std::tie(output, inverse, std::ignore) = unique_cpu_sorted_template<scalar_t>(
+        self, return_inverse, /* return_counts */false, IsUnique<scalar_t, /* equal_nan */false>());
     return std::make_tuple(output, inverse);
   });
 }
 
 std::tuple<Tensor, Tensor, Tensor>
 _unique2_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
-  return AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kBool, kHalf, self.scalar_type(), "unique", [&] {
-    return unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
+  if (self.scalar_type() == kBool) {
+    return unique_cpu_bool_template(self, return_inverse, return_counts);
+  }
+  return AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "unique", [&] {
+    // The current CPU implementation of unique always sorts,
+    // since this is faster than using a hash table
+    return unique_cpu_sorted_template<scalar_t>(
+        self, return_inverse, return_counts, IsUnique<scalar_t, /* equal_nan */ false>());
   });
 }