Implement NumPy-like function torch.msort() (#48440)
Summary:
- Related with https://github.com/pytorch/pytorch/issues/38349
- Implementing the NumPy-like function `torch.msort()` .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48440
Reviewed By: bdhirsh
Differential Revision: D25265753
Pulled By: mruberry
fbshipit-source-id: 7709ac5e5667e7541a3dc9048b9c9896b1a6dfa1
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 817ccb2..9295279 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -497,6 +497,7 @@
_(aten, mse_loss) \
_(aten, mse_loss_backward) \
_(aten, mse_loss_forward) \
+_(aten, msort) \
_(aten, multi_margin_loss) \
_(aten, multi_margin_loss_backward) \
_(aten, multi_margin_loss_forward) \
diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp
index c576832..e365d48 100644
--- a/aten/src/ATen/native/Sorting.cpp
+++ b/aten/src/ATen/native/Sorting.cpp
@@ -708,5 +708,15 @@
return sort_out_cpu(values, indices, self, dim, descending);
}
+Tensor& msort_out(Tensor& values, const Tensor& self) {
+ Tensor indices = at::empty({0}, self.options().dtype(kLong));
+ at::sort_out(values, indices, self, 0, false);
+ return values;
+}
+
+Tensor msort(const Tensor& self) {
+ return std::get<0>(at::sort(self, 0, false));
+}
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 46b7173..c7c1dc3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6779,6 +6779,16 @@
- func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
variants: method, function
+- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+ dispatch:
+ Math: msort_out
+
+- func: msort(Tensor self) -> Tensor
+ use_c10_dispatcher: full
+ variants: method, function
+ dispatch:
+ Math: msort
+
- func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
use_c10_dispatcher: full
variants: method, function
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 578f6a8..f523d60 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -463,6 +463,7 @@
.. automethod:: mode
.. automethod:: movedim
.. automethod:: moveaxis
+ .. automethod:: msort
.. automethod:: mul
.. automethod:: mul_
.. automethod:: multiply
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index ca288ff..98934e5 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -414,6 +414,7 @@
not_equal
sort
topk
+ msort
Spectral Ops
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index 494541a..ebd7c05 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -1,11 +1,12 @@
import torch
+import numpy as np
import random
from torch._six import nan
from itertools import product
from torch.testing._internal.common_utils import \
- (TestCase, run_tests)
+ (TestCase, run_tests, make_tensor)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA,
skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA)
@@ -112,6 +113,33 @@
self.assertIsOrdered('descending', x, res2val, res2ind,
'random with NaNs')
+ @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
+ def test_msort(self, device, dtype):
+ def test(shape):
+ tensor = make_tensor(shape, device, dtype, low=-9, high=9)
+ if tensor.size() != torch.Size([]):
+ expected = torch.from_numpy(np.msort(tensor.cpu().numpy()))
+ else:
+ expected = tensor # numpy.msort() does not support empty shapes tensor
+
+ result = torch.msort(tensor)
+ self.assertEqual(result, expected)
+
+ out = torch.empty_like(result)
+ torch.msort(tensor, out=out)
+ self.assertEqual(out, expected)
+
+ shapes = (
+ [],
+ [0, ],
+ [20, ],
+ [1, 20],
+ [30, 30],
+ [10, 20, 30]
+ )
+ for shape in shapes:
+ test(shape)
+
def test_topk(self, device):
def topKViaSort(t, k, dim, dir):
sorted, indices = t.sort(dim, dir)
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 9a4b90e..ef7d715 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -3382,6 +3382,13 @@
See :func:`torch.sort`
""")
+add_docstr_all('msort',
+ r"""
+msort() -> Tensor
+
+See :func:`torch.msort`
+""")
+
add_docstr_all('argsort',
r"""
argsort(dim=-1, descending=False) -> LongTensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 022170e..5cc796d 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -7556,6 +7556,35 @@
[3, 2, 1, 0]])
""".format(**common_args))
+add_docstr(torch.msort,
+ r"""
+msort(input, *, out=None) -> Tensor
+
+Sorts the elements of the :attr:`input` tensor along its first dimension
+in ascending order by value.
+
+.. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`.
+ See also :func:`torch.sort`.
+
+Args:
+ {input}
+
+Keyword args:
+ {out}
+
+Example::
+
+ >>> t = torch.randn(3, 4)
+ >>> t
+ tensor([[-0.1321, 0.4370, -1.2631, -1.1289],
+ [-2.0527, -1.1250, 0.2275, 0.3077],
+ [-0.0881, -0.1259, -0.5495, 1.0284]])
+ >>> torch.msort(t)
+ tensor([[-2.0527, -1.1250, -1.2631, -1.1289],
+ [-0.1321, -0.1259, -0.5495, 0.3077],
+ [-0.0881, 0.4370, 0.2275, 1.0284]])
+""".format(**common_args))
+
add_docstr(torch.sparse_coo_tensor,
r"""
sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 590bc80..f965496 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -525,6 +525,7 @@
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
torch.movedim: lambda input, source, destination: -1,
torch.moveaxis: lambda input, source, destination: -1,
+ torch.msort: lambda input, descending=False, out=None: -1,
torch.mul: lambda input, other, out=None: -1,
torch.multiply: lambda input, other, out=None: -1,
torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 0b245ec..226dbdd 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1721,6 +1721,7 @@
('sort', (), NO_ARGS, 'scalar'),
('sort', (), (0,), 'dim_scalar'),
('sort', (), (0, True), 'dim_desc_scalar'),
+ ('msort', (S, M, S), NO_ARGS),
('topk', (S, M, S), (3,)),
('topk', (S, M, S), (3, 1), 'dim', (), [1]),
('topk', (S, M, S), (3, 1, True), 'dim_desc', (), [1]),