[numpy] Add `torch.moveaxis` (#48581)

Summary:
Adds `torch.moveaxis`, an alias for `torch.movedim` that matches NumPy's `np.moveaxis`, exposed as both a function and a tensor method.

References: https://github.com/pytorch/pytorch/issues/38349, #36048, https://github.com/pytorch/pytorch/pull/41480#issuecomment-734398262

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48581

Reviewed By: bdhirsh

Differential Revision: D25276307

Pulled By: mruberry

fbshipit-source-id: 3e3e4df1343c5ce5b71457badc43f08c419ec5c3
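
For reference, a minimal usage sketch of the new alias (shapes follow the `movedim` semantics shown in the docs added below):

```python
import torch

t = torch.randn(3, 2, 1)

# moveaxis is a pure alias: both calls produce the same result.
assert torch.equal(torch.moveaxis(t, 1, 0), torch.movedim(t, 1, 0))
assert torch.moveaxis(t, 1, 0).shape == torch.Size([2, 3, 1])

# The sequence overload moves several dims at once.
assert torch.moveaxis(t, (1, 2), (0, 1)).shape == torch.Size([2, 1, 3])
```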
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 56f1f0f..e3a855b 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -546,7 +546,6 @@
 _(aten, pdist) \
 _(aten, cdist) \
 _(aten, permute) \
-_(aten, movedim) \
 _(aten, pin_memory) \
 _(aten, pinverse) \
 _(aten, pixel_shuffle) \
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index eede5f4..72cf483 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -284,6 +284,8 @@
   _(aten, swapaxes_)                 \
   _(aten, swapdims)                  \
   _(aten, swapdims_)                 \
+  _(aten, movedim)                   \
+  _(aten, moveaxis)                  \
   FORALL_ATEN_BASE_SYMBOLS(_)        \
   _(onnx, Add)                       \
   _(onnx, Concat)                    \
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 11889da..eda688a 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -2002,6 +2002,14 @@
   return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
 }
 
+Tensor moveaxis(const Tensor& self, IntArrayRef src, IntArrayRef dst) {
+  return at::movedim(self, src, dst);
+}
+
+Tensor moveaxis(const Tensor& self, int64_t src, int64_t dst) {
+  return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
+}
+
 Tensor swapaxes(const Tensor& self, int64_t axis0, int64_t axis1) {
   return self.transpose(axis0, axis1);
 }
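
The two overloads mirror `movedim`'s own pair: the int form wraps its arguments in single-element dim lists and both forward to `at::movedim`. A sketch of the equivalent Python-level behavior, assuming the dispatcher resolves the scalar and list forms the same way it does for `movedim`:

```python
import torch

t = torch.randn(4, 3, 2)

# Scalar overload: ints behave like single-element dim lists,
# so these two spellings are interchangeable.
assert torch.equal(torch.moveaxis(t, 0, -1), torch.moveaxis(t, (0,), (-1,)))

# List overload: forwards straight to movedim.
assert torch.equal(torch.moveaxis(t, (0, 1), (1, 0)),
                   torch.movedim(t, (0, 1), (1, 0)))
```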
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3cc6f3a..8f237f9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3057,6 +3057,15 @@
   use_c10_dispatcher: full
   variants: function, method
 
+# moveaxis, alias for movedim
+- func: moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+  use_c10_dispatcher: full
+  variants: function, method
+
 # Only exposed from C++ -- in Python,
 # we expose it as an attribute `T`, not a function.
 #
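
Because both schemas declare `variants: function, method`, the alias is callable in either style, and the keyword names follow the schema. A quick sketch:

```python
import torch

t = torch.randn(2, 3, 4)

# Function variant and method variant hit the same kernel.
assert torch.equal(torch.moveaxis(t, 0, 2), t.moveaxis(0, 2))

# Keyword names come from the schema: `source` and `destination`.
assert torch.moveaxis(t, source=0, destination=2).shape == torch.Size([3, 4, 2])
```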
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index d7b0af7..3f12004 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -461,6 +461,7 @@
       :noindex:
    .. automethod:: mode
    .. automethod:: movedim
+   .. automethod:: moveaxis
    .. automethod:: mul
    .. automethod:: mul_
    .. automethod:: multiply
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index b16f14d..d7c80de 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -92,6 +92,7 @@
     index_select
     masked_select
     movedim
+    moveaxis
     narrow
     nonzero
     reshape
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 5c8acd7..125fa7a 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7240,15 +7240,16 @@
         self.assertEqual(c.grad.stride(), (2, 1))
 
     def test_movedim(self, device):
-        x = torch.randn(4, 3, 2, 1, dtype=torch.double, device=device, requires_grad=True)
+        for fn in [torch.movedim, torch.moveaxis]:
+            x = torch.randn(4, 3, 2, 1, dtype=torch.double, device=device, requires_grad=True)
 
-        # Positive axis
-        gradcheck(lambda x: torch.movedim(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
-        gradgradcheck(lambda x: torch.movedim(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
+            # Positive axis
+            gradcheck(lambda x: fn(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
+            gradgradcheck(lambda x: fn(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
 
-        # Negative axis
-        gradcheck(lambda x: torch.movedim(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
-        gradgradcheck(lambda x: torch.movedim(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
+            # Negative axis
+            gradcheck(lambda x: fn(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
+            gradgradcheck(lambda x: fn(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
 
     def _test_atleast(self, device, torch_fn):
         # 0-dim
diff --git a/test/test_op_aliases.py b/test/test_op_aliases.py
index 7ba2f48..8a9b7c3 100644
--- a/test/test_op_aliases.py
+++ b/test/test_op_aliases.py
@@ -161,6 +161,8 @@
               lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
     AliasInfo('row_stack', torch.row_stack, 'vstack', torch.vstack,
               lambda d: ((torch.randn(20, device=d), torch.randn(20, device=d)))),
+    AliasInfo('moveaxis', torch.moveaxis, 'movedim', torch.movedim,
+              lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
 )
 
 # Placeholder test class for validating that aliases are correctly
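
Outside the harness, the property the new `AliasInfo` entry encodes can be checked directly. A standalone sketch (the helper name here is illustrative, not the test's internals):

```python
import torch

def check_alias(alias_fn, original_fn, make_input, args):
    # An alias must agree with the original op on values.
    x = make_input()
    assert torch.equal(alias_fn(x, *args), original_fn(x, *args))

check_alias(torch.moveaxis, torch.movedim,
            lambda: torch.randn(20, 3, 2, 1), (3, 1))
```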
diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py
index 3a7afa7..4332150 100644
--- a/test/test_shape_ops.py
+++ b/test/test_shape_ops.py
@@ -86,95 +86,97 @@
         shape = self._rand_shape(4, min_size=5, max_size=10)
         x = _generate_input(shape, dtype, device, False)
 
-        # Invalid `source` and `destination` dimension
-        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
-            torch.movedim(x, 5, 0)
+        for fn in [torch.movedim, torch.moveaxis]:
+            # Invalid `source` and `destination` dimension
+            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
+                fn(x, 5, 0)
 
-        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
-            torch.movedim(x, 0, 5)
+            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
+                fn(x, 0, 5)
 
-        # Mismatch in size of `source` and `destination`
-        with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"):
-            torch.movedim(x, (1, 0), (0, ))
+            # Mismatch in size of `source` and `destination`
+            with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"):
+                fn(x, (1, 0), (0, ))
 
-        with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
-            torch.movedim(x, (0, 0), (0, 1))
+            with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
+                fn(x, (0, 0), (0, 1))
 
-        with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
-            torch.movedim(x, (0, 1, 0), (0, 1, 2))
+            with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
+                fn(x, (0, 1, 0), (0, 1, 2))
 
-        with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
-            torch.movedim(x, (0, 1), (1, 1))
+            with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
+                fn(x, (0, 1), (1, 1))
 
-        with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
-            torch.movedim(x, (0, 1, 2), (1, 0, 1))
+            with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
+                fn(x, (0, 1, 2), (1, 0, 1))
 
     @dtypes(torch.int64, torch.float, torch.complex128)
     def test_movedim(self, device, dtype):
-        for nd in range(5):
-            shape = self._rand_shape(nd, min_size=5, max_size=10)
-            x = _generate_input(shape, dtype, device, with_extremal=False)
-            for random_negative in [True, False]:
-                for src_dim, dst_dim in permutations(range(nd), r=2):
-                    random_prob = random.random()
+        for fn in [torch.moveaxis, torch.movedim]:
+            for nd in range(5):
+                shape = self._rand_shape(nd, min_size=5, max_size=10)
+                x = _generate_input(shape, dtype, device, with_extremal=False)
+                for random_negative in [True, False]:
+                    for src_dim, dst_dim in permutations(range(nd), r=2):
+                        random_prob = random.random()
 
-                    if random_negative and random_prob > 0.66:
-                        src_dim = src_dim - nd
-                    elif random_negative and random_prob > 0.33:
-                        dst_dim = dst_dim - nd
-                    elif random_negative:
-                        src_dim = src_dim - nd
-                        dst_dim = dst_dim - nd
+                        if random_negative and random_prob > 0.66:
+                            src_dim = src_dim - nd
+                        elif random_negative and random_prob > 0.33:
+                            dst_dim = dst_dim - nd
+                        elif random_negative:
+                            src_dim = src_dim - nd
+                            dst_dim = dst_dim - nd
 
-                    # Integer `source` and `destination`
-                    torch_fn = partial(torch.movedim, source=src_dim, destination=dst_dim)
-                    np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim)
-                    self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+                        # Integer `source` and `destination`
+                        torch_fn = partial(fn, source=src_dim, destination=dst_dim)
+                        np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim)
+                        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
 
-                if nd == 0:
-                    continue
+                    if nd == 0:
+                        continue
 
-                def make_index_negative(sequence, idx):
-                    sequence = list(sequence)
-                    sequence[random_idx] = sequence[random_idx] - nd
-                    return tuple(src_sequence)
+                    def make_index_negative(sequence, idx):
+                        sequence = list(sequence)
+                        sequence[idx] = sequence[idx] - nd
+                        return tuple(sequence)
 
-                for src_sequence in permutations(range(nd), r=random.randint(1, nd)):
-                    # Sequence `source` and `destination`
-                    dst_sequence = tuple(random.sample(range(nd), len(src_sequence)))
+                    for src_sequence in permutations(range(nd), r=random.randint(1, nd)):
+                        # Sequence `source` and `destination`
+                        dst_sequence = tuple(random.sample(range(nd), len(src_sequence)))
 
-                    # Randomly change a dim to a negative dim representation of itself.
-                    random_prob = random.random()
-                    if random_negative and random_prob > 0.66:
-                        random_idx = random.randint(0, len(src_sequence) - 1)
-                        src_sequence = make_index_negative(src_sequence, random_idx)
-                    elif random_negative and random_prob > 0.33:
-                        random_idx = random.randint(0, len(src_sequence) - 1)
-                        dst_sequence = make_index_negative(dst_sequence, random_idx)
-                    elif random_negative:
-                        random_idx = random.randint(0, len(src_sequence) - 1)
-                        dst_sequence = make_index_negative(dst_sequence, random_idx)
-                        random_idx = random.randint(0, len(src_sequence) - 1)
-                        src_sequence = make_index_negative(src_sequence, random_idx)
+                        # Randomly change a dim to a negative dim representation of itself.
+                        random_prob = random.random()
+                        if random_negative and random_prob > 0.66:
+                            random_idx = random.randint(0, len(src_sequence) - 1)
+                            src_sequence = make_index_negative(src_sequence, random_idx)
+                        elif random_negative and random_prob > 0.33:
+                            random_idx = random.randint(0, len(src_sequence) - 1)
+                            dst_sequence = make_index_negative(dst_sequence, random_idx)
+                        elif random_negative:
+                            random_idx = random.randint(0, len(src_sequence) - 1)
+                            dst_sequence = make_index_negative(dst_sequence, random_idx)
+                            random_idx = random.randint(0, len(src_sequence) - 1)
+                            src_sequence = make_index_negative(src_sequence, random_idx)
 
-                    torch_fn = partial(torch.movedim, source=src_sequence, destination=dst_sequence)
-                    np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence)
-                    self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+                        torch_fn = partial(fn, source=src_sequence, destination=dst_sequence)
+                        np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence)
+                        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
 
-        # Move dim to same position
-        x = torch.randn(2, 3, 5, 7, 11)
-        torch_fn = partial(torch.movedim, source=(0, 1), destination=(0, 1))
-        np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
-        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+            # Move dim to same position
+            x = torch.randn(2, 3, 5, 7, 11)
+            torch_fn = partial(fn, source=(0, 1), destination=(0, 1))
+            np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
+            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
 
-        torch_fn = partial(torch.movedim, source=1, destination=1)
-        np_fn = partial(np.moveaxis, source=1, destination=1)
-        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+            torch_fn = partial(fn, source=1, destination=1)
+            np_fn = partial(np.moveaxis, source=1, destination=1)
+            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
 
-        # Empty Sequence
-        torch_fn = partial(torch.movedim, source=(), destination=())
-        np_fn = partial(np.moveaxis, source=(), destination=())
-        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+            # Empty Sequence
+            torch_fn = partial(fn, source=(), destination=())
+            np_fn = partial(np.moveaxis, source=(), destination=())
+            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
 
     @dtypes(torch.float, torch.bool)
     def test_diag(self, device, dtype):
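
The comparison these tests drive through `compare_with_numpy` boils down to the following parity check. A minimal sketch, assuming NumPy is available:

```python
import numpy as np
import torch

x = torch.randn(2, 3, 5)

# Parity with np.moveaxis for int, sequence, and empty-sequence dims.
for src, dst in [(0, -1), ((0, 1), (1, 0)), ((), ())]:
    expected = np.moveaxis(x.numpy(), src, dst)
    actual = torch.moveaxis(x, src, dst)
    assert np.array_equal(actual.numpy(), expected)
```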
diff --git a/test/test_view_ops.py b/test/test_view_ops.py
index d4e59a3..15f1bcd 100644
--- a/test/test_view_ops.py
+++ b/test/test_view_ops.py
@@ -535,11 +535,12 @@
                 out[idx_1, idx_2] = random.random()
                 self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])
 
-        op = partial(torch.movedim, source=(0, 1), destination=(1, 0))
-        run_test(device, op)
+        for fn in [torch.movedim, torch.moveaxis]:
+            op = partial(fn, source=(0, 1), destination=(1, 0))
+            run_test(device, op)
 
-        op = partial(torch.movedim, source=0, destination=1)
-        run_test(device, op)
+            op = partial(fn, source=0, destination=1)
+            run_test(device, op)
 
 class TestOldViewOps(TestCase):
     def test_ravel(self, device):
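
What `TestViewOps` asserts, in miniature: `moveaxis`, like `movedim` (note the `Tensor(a)` annotation in the schema), returns a view, so writes through the result are visible in the base tensor. A sketch:

```python
import torch

base = torch.zeros(2, 3)
view = torch.moveaxis(base, 0, 1)  # shape (3, 2), shares storage with `base`

view[0, 1] = 7.0
assert base[1, 0] == 7.0  # the write is visible through the base tensor
```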
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 99f504a..87bbf38 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2403,6 +2403,12 @@
 See :func:`torch.movedim`
 """)
 
+add_docstr_all('moveaxis', r"""
+moveaxis(source, destination) -> Tensor
+
+See :func:`torch.moveaxis`
+""")
+
 add_docstr_all('mul', r"""
 mul(value) -> Tensor
 
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index dd4be74..7852f3b 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5749,6 +5749,43 @@
             [[-0.8437,  0.1727, -0.1398]]])
 """.format(**common_args))
 
+add_docstr(torch.moveaxis, r"""
+moveaxis(input, source, destination) -> Tensor
+
+Alias for :func:`torch.movedim`.
+
+This function is equivalent to NumPy's ``moveaxis`` function.
+
+Examples::
+
+    >>> t = torch.randn(3, 2, 1)
+    >>> t
+    tensor([[[-0.3362],
+            [-0.8437]],
+
+            [[-0.9627],
+            [ 0.1727]],
+
+            [[ 0.5173],
+            [-0.1398]]])
+    >>> torch.moveaxis(t, 1, 0).shape
+    torch.Size([2, 3, 1])
+    >>> torch.moveaxis(t, 1, 0)
+    tensor([[[-0.3362],
+            [-0.9627],
+            [ 0.5173]],
+
+            [[-0.8437],
+            [ 0.1727],
+            [-0.1398]]])
+    >>> torch.moveaxis(t, (1, 2), (0, 1)).shape
+    torch.Size([2, 1, 3])
+    >>> torch.moveaxis(t, (1, 2), (0, 1))
+    tensor([[[-0.3362, -0.9627,  0.5173]],
+
+            [[-0.8437,  0.1727, -0.1398]]])
+""".format(**common_args))
+
 add_docstr(torch.swapdims, r"""
 swapdims(input, dim0, dim1) -> Tensor
 
diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp
index 2f463f7..e032eeb 100644
--- a/torch/csrc/jit/passes/normalize_ops.cpp
+++ b/torch/csrc/jit/passes/normalize_ops.cpp
@@ -103,6 +103,7 @@
       {aten::swapdims_, aten::transpose_},
       {aten::swapaxes, aten::transpose},
       {aten::swapaxes_, aten::transpose_},
+      {aten::moveaxis, aten::movedim},
   };
   return alias_map;
 }
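
With this alias-map entry, TorchScript should canonicalize `aten::moveaxis` to `aten::movedim` during op normalization. A hedged sketch of one way to observe this; exactly when the pass runs relative to scripting is an assumption here:

```python
import torch

@torch.jit.script
def f(t: torch.Tensor) -> torch.Tensor:
    return torch.moveaxis(t, 0, 1)

f(torch.randn(2, 3))  # run once so the optimization pipeline (and this pass) executes

# The optimized graph should reference aten::movedim rather than aten::moveaxis.
print(torch.jit.last_executed_optimized_graph())
```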
diff --git a/torch/overrides.py b/torch/overrides.py
index f6a4937..f7b9bed 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -522,6 +522,7 @@
         torch.mm: lambda input, mat2, out=None: -1,
         torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
         torch.movedim: lambda input, source, destination: -1,
+        torch.moveaxis: lambda input, source, destination: -1,
         torch.mul: lambda input, other, out=None: -1,
         torch.multiply: lambda input, other, out=None: -1,
         torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
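
The override-table entry means `torch.moveaxis` participates in the `__torch_function__` protocol like any other overridable op. A minimal sketch, assuming a build where `torch.Tensor` provides the default `__torch_function__` (the subclass name is illustrative):

```python
import torch

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Record the intercepted op, then defer to the default behavior.
        kwargs = kwargs or {}
        print(f"intercepted: {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)

t = torch.randn(2, 3).as_subclass(LoggingTensor)
torch.moveaxis(t, 0, 1)  # expected to print "intercepted: moveaxis"
```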