Implement NumPy-like function torch.msort() (#48440)

Summary:
- Related with https://github.com/pytorch/pytorch/issues/38349
- Implementing the NumPy-like function `torch.msort()` .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48440

Reviewed By: bdhirsh

Differential Revision: D25265753

Pulled By: mruberry

fbshipit-source-id: 7709ac5e5667e7541a3dc9048b9c9896b1a6dfa1
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 817ccb2..9295279 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -497,6 +497,7 @@
 _(aten, mse_loss) \
 _(aten, mse_loss_backward) \
 _(aten, mse_loss_forward) \
+_(aten, msort) \
 _(aten, multi_margin_loss) \
 _(aten, multi_margin_loss_backward) \
 _(aten, multi_margin_loss_forward) \
diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp
index c576832..e365d48 100644
--- a/aten/src/ATen/native/Sorting.cpp
+++ b/aten/src/ATen/native/Sorting.cpp
@@ -708,5 +708,15 @@
   return sort_out_cpu(values, indices, self, dim, descending);
 }
 
+Tensor& msort_out(Tensor& values, const Tensor& self) {
+  Tensor indices = at::empty({0}, self.options().dtype(kLong));
+  at::sort_out(values, indices, self, 0, false);
+  return values;
+}
+
+Tensor msort(const Tensor& self) {
+  return std::get<0>(at::sort(self, 0, false));
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 46b7173..c7c1dc3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6779,6 +6779,16 @@
 - func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
   variants: method, function
 
+- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    Math: msort_out
+
+- func: msort(Tensor self) -> Tensor
+  use_c10_dispatcher: full
+  variants: method, function
+  dispatch:
+    Math: msort
+
 - func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 578f6a8..f523d60 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -463,6 +463,7 @@
    .. automethod:: mode
    .. automethod:: movedim
    .. automethod:: moveaxis
+   .. automethod:: msort
    .. automethod:: mul
    .. automethod:: mul_
    .. automethod:: multiply
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index ca288ff..98934e5 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -414,6 +414,7 @@
     not_equal
     sort
     topk
+    msort
 
 
 Spectral Ops
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index 494541a..ebd7c05 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -1,11 +1,12 @@
 import torch
+import numpy as np
 
 import random
 from torch._six import nan
 from itertools import product
 
 from torch.testing._internal.common_utils import \
-    (TestCase, run_tests)
+    (TestCase, run_tests, make_tensor)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA,
      skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA)
@@ -112,6 +113,33 @@
         self.assertIsOrdered('descending', x, res2val, res2ind,
                              'random with NaNs')
 
+    @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
+    def test_msort(self, device, dtype):
+        def test(shape):
+            tensor = make_tensor(shape, device, dtype, low=-9, high=9)
+            if tensor.size() != torch.Size([]):
+                expected = torch.from_numpy(np.msort(tensor.cpu().numpy()))
+            else:
+                expected = tensor  # numpy.msort() does not support empty shapes tensor
+
+            result = torch.msort(tensor)
+            self.assertEqual(result, expected)
+
+            out = torch.empty_like(result)
+            torch.msort(tensor, out=out)
+            self.assertEqual(out, expected)
+
+        shapes = (
+            [],
+            [0, ],
+            [20, ],
+            [1, 20],
+            [30, 30],
+            [10, 20, 30]
+        )
+        for shape in shapes:
+            test(shape)
+
     def test_topk(self, device):
         def topKViaSort(t, k, dim, dir):
             sorted, indices = t.sort(dim, dir)
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 9a4b90e..ef7d715 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -3382,6 +3382,13 @@
 See :func:`torch.sort`
 """)
 
+add_docstr_all('msort',
+               r"""
+msort() -> Tensor
+
+See :func:`torch.msort`
+""")
+
 add_docstr_all('argsort',
                r"""
 argsort(dim=-1, descending=False) -> LongTensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 022170e..5cc796d 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -7556,6 +7556,35 @@
             [3, 2, 1, 0]])
 """.format(**common_args))
 
+add_docstr(torch.msort,
+           r"""
+msort(input, *, out=None) -> Tensor
+
+Sorts the elements of the :attr:`input` tensor along its first dimension
+in ascending order by value.
+
+.. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`.
+          See also :func:`torch.sort`.
+
+Args:
+    {input}
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> t = torch.randn(3, 4)
+    >>> t
+    tensor([[-0.1321,  0.4370, -1.2631, -1.1289],
+            [-2.0527, -1.1250,  0.2275,  0.3077],
+            [-0.0881, -0.1259, -0.5495,  1.0284]])
+    >>> torch.msort(t)
+    tensor([[-2.0527, -1.1250, -1.2631, -1.1289],
+            [-0.1321, -0.1259, -0.5495,  0.3077],
+            [-0.0881,  0.4370,  0.2275,  1.0284]])
+""".format(**common_args))
+
 add_docstr(torch.sparse_coo_tensor,
            r"""
 sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 590bc80..f965496 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -525,6 +525,7 @@
         torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
         torch.movedim: lambda input, source, destination: -1,
         torch.moveaxis: lambda input, source, destination: -1,
+        torch.msort: lambda input, descending=False, out=None: -1,
         torch.mul: lambda input, other, out=None: -1,
         torch.multiply: lambda input, other, out=None: -1,
         torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 0b245ec..226dbdd 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1721,6 +1721,7 @@
         ('sort', (), NO_ARGS, 'scalar'),
         ('sort', (), (0,), 'dim_scalar'),
         ('sort', (), (0, True), 'dim_desc_scalar'),
+        ('msort', (S, M, S), NO_ARGS),
         ('topk', (S, M, S), (3,)),
         ('topk', (S, M, S), (3, 1), 'dim', (), [1]),
         ('topk', (S, M, S), (3, 1, True), 'dim_desc', (), [1]),