Adds `dim` argument to `torch.unique` (#10423)
Summary:
Initial version of `unique` supporting a `dim` argument.
As discussed in [this issue](https://github.com/pytorch/pytorch/issues/9997), I added a `dim` argument to `torch.unique` with the same behavior as [numpy](https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.unique.html).
Since the implementation is based on `std::unique`/`thrust::unique`, the tensor always needs to be sorted, so the `sorted` argument of `torch.unique` has no effect, just as in the CUDA version of the plain `torch.unique`.
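With the new argument, usage mirrors `np.unique(..., axis=...)`. A minimal sketch (the example values are illustrative and not taken from the test suite):

```python
import torch

x = torch.tensor([[1, 3],
                  [2, 3],
                  [1, 3]])

# Unique rows along dim=0; the result is always sorted (see note above).
output, inverse = torch.unique(x, return_inverse=True, dim=0)
# output:  tensor([[1, 3],
#                  [2, 3]])
# inverse: tensor([0, 1, 0])  -- maps each input row to its row in `output`
```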
To check the performance and verify matching behavior between `torch.unique` and `np.unique`, I used [this gist](https://gist.github.com/ptrblck/ac0dc862f4e1766f0e1036c252cdb105).
Currently we achieve the following timings for an input of `x = torch.randint(2, (1000, 1000))`:
(Each value is the average of the times measured over both dimensions; a sketch of the timing loop follows the table.)
| Device | PyTorch (return_inverse=False) | Numpy (return_inverse=False) | PyTorch (return_inverse=True) | Numpy (return_inverse=True) |
| --- | --- | --- | --- | --- |
| CPU | ~0.007331s | ~0.022452s | ~0.011139s | ~0.044800s |
| GPU | ~0.006154s | - | ~0.105373s | - |
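The numbers above come from a simple wall-clock loop along the lines of the linked gist; the sketch below is only an approximation of that setup, not the gist itself (the repetition count, helper names, and synchronization points are assumptions):

```python
import time

import numpy as np
import torch

def bench_torch(t, dim, return_inverse, reps=10):
    # Synchronize around the loop so GPU timings are meaningful.
    if t.is_cuda:
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(reps):
        torch.unique(t, return_inverse=return_inverse, dim=dim)
    if t.is_cuda:
        torch.cuda.synchronize()
    return (time.time() - start) / reps

def bench_numpy(a, axis, return_inverse, reps=10):
    start = time.time()
    for _ in range(reps):
        np.unique(a, return_inverse=return_inverse, axis=axis)
    return (time.time() - start) / reps

x = torch.randint(2, (1000, 1000))

# Average over both dimensions, as in the table above.
torch_cpu = sum(bench_torch(x, d, False) for d in (0, 1)) / 2
numpy_cpu = sum(bench_numpy(x.numpy(), d, False) for d in (0, 1)) / 2
print('torch: %.6fs, numpy: %.6fs' % (torch_cpu, numpy_cpu))
```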
Many thanks to colesbury for the awesome mentoring and the valuable advice on the general implementation and performance issues!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10423
Differential Revision: D9517289
Pulled By: soumith
fbshipit-source-id: a4754f805223589c2847c98b8e4e39d8c3ddb7b5
diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index d9bd94e..d5ff300 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -47,6 +47,82 @@
}
return std::make_tuple(output, inverse_indices);
}
+
+template<class ForwardIt>
+ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
+ std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
+ if (first == last) {
+ return last;
+ }
+ // save to calculate distance to iterators
+ ForwardIt begin = first;
+
+ // set first inverse index
+ inverse_indices_vec[indices[0]] = 0;
+
+ ForwardIt result = first;
+ while (++first != last) {
+ if (!at::equal(*result, *first) && ++result != first) {
+ *result = std::move(*first);
+ }
+ int64_t idx_result = std::distance(begin, result);
+ int64_t idx_first = std::distance(begin, first);
+ inverse_indices_vec[indices[idx_first]] = idx_result;
+ }
+
+ return ++result;
+ }
+
+template <typename scalar_t>
+std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
+ const Tensor& self,
+ const int64_t dim,
+ const bool return_inverse) {
+ // reshape tensor as [dim, -1]
+ Tensor input_flat = self.transpose(dim, 0);
+ auto orig_sizes = input_flat.sizes().vec();
+ input_flat = input_flat.contiguous().view({input_flat.size(0), -1});
+
+ std::vector<int64_t> indices(input_flat.size(0));
+ std::iota(indices.begin(), indices.end(), 0);
+ int64_t numel = input_flat.size(1);
+ scalar_t* input_flat_ptr = ((scalar_t*)input_flat.data_ptr());
+
+ // sort indices using data
+ std::sort(indices.begin(), indices.end(),
+ [&](int64_t a, int64_t b) -> bool {
+ for (int64_t i = 0; i < numel; ++i) {
+ scalar_t lhs = input_flat_ptr[i + a * numel];
+ scalar_t rhs = input_flat_ptr[i + b * numel];
+ if (lhs < rhs) {
+ return true;
+ } else if (lhs > rhs) {
+ return false;
+ }
+ }
+ return false;
+ });
+
+ Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.type());
+ for (int i = 0; i < indices.size(); ++i) {
+ input_sorted[i] = input_flat[indices[i]];
+ }
+
+ Tensor inverse_indices = at::empty(indices.size(), self.type().toScalarType(kLong));
+ std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
+ auto last = _unique_dim_cpu_impl(
+ input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
+ input_unbind.erase(last, input_unbind.end());
+
+ // reshape back
+ auto output = at::stack(input_unbind, 0);
+ auto new_sizes = std::vector<int64_t>(orig_sizes);
+ new_sizes[0] = -1;
+ output = output.view(new_sizes);
+ output = output.transpose(0, dim);
+
+ return std::make_tuple(output, inverse_indices);
+}
} // namespace
std::tuple<Tensor, Tensor>
@@ -56,5 +132,13 @@
});
}
+std::tuple<Tensor, Tensor>
+_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+ return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+ // The current implementation using `dim` always sorts due to unhashable tensors
+ return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
+ });
+}
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu
index f2e13b4..c29337f 100644
--- a/aten/src/ATen/native/cuda/Unique.cu
+++ b/aten/src/ATen/native/cuda/Unique.cu
@@ -69,6 +69,92 @@
return std::tuple<Tensor, Tensor>(output, inverse_indices);
}
+
+template <typename scalar_t>
+ std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
+ const Tensor& self,
+ const int64_t dim,
+ const bool return_inverse) {
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+ auto policy = thrust::cuda::par(allocator).on(stream);
+
+ Tensor input_flat = self.transpose(dim, 0);
+ auto orig_sizes = input_flat.sizes().vec();
+ input_flat = input_flat.contiguous().view({input_flat.size(0), -1});
+
+ scalar_t* input_flat_ptr = input_flat.data<scalar_t>();
+
+ Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
+ int64_t* indices_ptr = indices.data<int64_t>();
+ int64_t numel = input_flat.size(1);
+
+ // sort indices using data
+ thrust::sort(policy, indices_ptr, indices_ptr + indices.numel(),
+ [=] __device__ (int64_t a, int64_t b) -> bool {
+ for (int64_t i = 0; i < numel; ++i) {
+ scalar_t lhs = input_flat_ptr[i + a * numel];
+ scalar_t rhs = input_flat_ptr[i + b * numel];
+ if (lhs < rhs) {
+ return true;
+ } else if (lhs > rhs) {
+ return false;
+ }
+ }
+ return false;
+ });
+
+ Tensor input_sorted = input_flat.index_select(0, indices);
+
+ // get unique tensors
+ scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
+ Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
+ int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
+ auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
+ [=] __device__ (int64_t a, int64_t b) -> bool {
+ for (int64_t i = 0; i < numel; ++i) {
+ scalar_t lhs = input_sorted_ptr[i + a * numel];
+ scalar_t rhs = input_sorted_ptr[i + b * numel];
+ if (lhs != rhs) {
+ return false;
+ }
+ }
+ return true;
+ });
+ input_sorted_indices.resize_(last - input_sorted_indices_ptr);
+ Tensor output = input_sorted.index_select(0, input_sorted_indices);
+
+ // reshape back
+ auto new_sizes = std::vector<int64_t>(orig_sizes);
+ new_sizes[0] = -1;
+ output = output.view(new_sizes);
+ output = output.transpose(0, dim);
+
+ // calculate inverse indices
+ Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
+ if (return_inverse) {
+ int64_t size = self.size(dim);
+ inverse_indices.resize_(size);
+ Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
+ mask[0] = 1;
+ for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
+ if (!at::equal(input_sorted[i], input_sorted[i+1])) {
+ mask[i+1] = 1;
+ } else {
+ mask[i+1] = 0;
+ }
+ }
+
+ Tensor imask = at::cumsum(mask, 0) - 1;
+ for (int i = 0; i < indices.size(0); ++i) {
+ inverse_indices[indices[i]] = imask[i];
+ }
+ }
+
+ THCudaCheck(cudaGetLastError());
+ return std::tuple<Tensor, Tensor>(output, inverse_indices);
+ }
} // namespace
#endif
@@ -86,5 +172,16 @@
#endif
}
+std::tuple<Tensor, Tensor>
+_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+ #ifndef __HIP_PLATFORM_HCC__
+ return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+ return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
+ });
+ #else
+ AT_ERROR("unique_dim_cuda: HIP not supported");
+ #endif
+}
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 466fe6c..cb194cd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1748,6 +1748,11 @@
CPU: _unique_cpu
CUDA: _unique_cuda
+- func: _unique_dim(Tensor self, int64_t dim, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor)
+ dispatch:
+ CPU: _unique_dim_cpu
+ CUDA: _unique_dim_cuda
+
- func: _unsafe_view(Tensor self, IntList size) -> Tensor
variants: function
diff --git a/test/test_torch.py b/test/test_torch.py
index 167a400..863f97f 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8485,6 +8485,67 @@
self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
+ def test_unique_dim(self):
+ def run_test(dtype=torch.float):
+ x = torch.tensor([[[1., 1.],
+ [0., 1.],
+ [2., 1.],
+ [0., 1.]],
+ [[1., 1.],
+ [0., 1.],
+ [2., 1.],
+ [0., 1.]]], dtype=dtype)
+ expected_unique_dim0 = torch.tensor([[[1., 1.],
+ [0., 1.],
+ [2., 1.],
+ [0., 1.]]], dtype=dtype)
+ expected_inverse_dim0 = torch.tensor([0, 0])
+ expected_unique_dim1 = torch.tensor([[[0., 1.],
+ [1., 1.],
+ [2., 1.]],
+ [[0., 1.],
+ [1., 1.],
+ [2., 1.]]], dtype=dtype)
+ expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
+ expected_unique_dim2 = torch.tensor([[[1., 1.],
+ [0., 1.],
+ [2., 1.],
+ [0., 1.]],
+ [[1., 1.],
+ [0., 1.],
+ [2., 1.],
+ [0., 1.]]], dtype=dtype)
+ expected_inverse_dim2 = torch.tensor([0, 1])
+
+ # dim0
+ x_unique = torch.unique(x, dim=0)
+ self.assertEqual(expected_unique_dim0, x_unique)
+
+ x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
+ self.assertEqual(expected_unique_dim0, x_unique)
+ self.assertEqual(expected_inverse_dim0, x_inverse)
+
+ # dim1
+ x_unique = torch.unique(x, dim=1)
+ self.assertEqual(expected_unique_dim1, x_unique)
+
+ x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
+ self.assertEqual(expected_unique_dim1, x_unique)
+ self.assertEqual(expected_inverse_dim1, x_inverse)
+
+ # dim2
+ x_unique = torch.unique(x, dim=2)
+ self.assertEqual(expected_unique_dim2, x_unique)
+
+ x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
+ self.assertEqual(expected_unique_dim2, x_unique)
+ self.assertEqual(expected_inverse_dim2, x_inverse)
+
+ run_test(torch.float)
+ run_test(torch.double)
+ run_test(torch.long)
+ run_test(torch.uint8)
+
@staticmethod
def _test_bincount(self, device):
# negative input throws
diff --git a/torch/functional.py b/torch/functional.py
index 055141b..8c78b6e 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -389,7 +389,7 @@
return tensor != tensor
-def unique(input, sorted=False, return_inverse=False):
+def unique(input, sorted=False, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
Arguments:
@@ -431,11 +431,19 @@
[ 1, 2]])
"""
- output, inverse_indices = torch._unique(
- input,
- sorted=sorted,
- return_inverse=return_inverse,
- )
+ if dim is not None:
+ output, inverse_indices = torch._unique_dim(
+ input,
+ dim,
+ sorted=sorted,
+ return_inverse=return_inverse
+ )
+ else:
+ output, inverse_indices = torch._unique(
+ input,
+ sorted=sorted,
+ return_inverse=return_inverse,
+ )
if return_inverse:
return output, inverse_indices
else:
diff --git a/torch/tensor.py b/torch/tensor.py
index ed2f7f0..904d3a5 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -319,13 +319,22 @@
"""
return self.clone().masked_fill_(mask, value)
- def unique(self, sorted=False, return_inverse=False):
+ def unique(self, sorted=False, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
See :func:`torch.unique`
"""
- output, inverse_indices = self._unique(
- sorted=sorted, return_inverse=return_inverse)
+ if dim is not None:
+ output, inverse_indices = self._unique_dim(
+ sorted=sorted,
+ return_inverse=return_inverse,
+ dim=dim
+ )
+ else:
+ output, inverse_indices = self._unique(
+ sorted=sorted,
+ return_inverse=return_inverse
+ )
if return_inverse:
return output, inverse_indices
else: