[numpy] Add `torch.moveaxis` (#48581)
Summary:
Adds `torch.moveaxis`, an alias for `torch.movedim` that matches NumPy's `np.moveaxis`. It is exposed as both a function and a tensor method, documented, covered by the alias/autograd/shape tests, and normalized to `aten::movedim` in TorchScript.
References: https://github.com/pytorch/pytorch/issues/38349, #36048, and https://github.com/pytorch/pytorch/pull/41480#issuecomment-734398262
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48581
Reviewed By: bdhirsh
Differential Revision: D25276307
Pulled By: mruberry
fbshipit-source-id: 3e3e4df1343c5ce5b71457badc43f08c419ec5c3
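
For context, a minimal sketch of what this change enables (illustrative values; `moveaxis` is a pure alias for `movedim`, matching NumPy's naming):

import numpy as np
import torch

t = torch.randn(3, 2, 1)
# The alias and the original agree for both the scalar and the sequence overloads.
assert torch.equal(torch.moveaxis(t, 0, -1), torch.movedim(t, 0, -1))
assert torch.moveaxis(t, (0, 2), (2, 0)).shape == torch.Size([1, 2, 3])
# Parity with NumPy, including negative dims.
assert np.moveaxis(t.numpy(), -1, 0).shape == tuple(torch.moveaxis(t, -1, 0).shape)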
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 56f1f0f..e3a855b 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -546,7 +546,6 @@
_(aten, pdist) \
_(aten, cdist) \
_(aten, permute) \
-_(aten, movedim) \
_(aten, pin_memory) \
_(aten, pinverse) \
_(aten, pixel_shuffle) \
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index eede5f4..72cf483 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -284,6 +284,8 @@
_(aten, swapaxes_) \
_(aten, swapdims) \
_(aten, swapdims_) \
+ _(aten, movedim) \
+ _(aten, moveaxis) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 11889da..eda688a 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -2002,6 +2002,14 @@
return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
}
+Tensor moveaxis(const Tensor& self, IntArrayRef src, IntArrayRef dst) {
+ return at::movedim(self, src, dst);
+}
+
+Tensor moveaxis(const Tensor& self, int64_t src, int64_t dst) {
+ return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
+}
+
Tensor swapaxes(const Tensor& self, int64_t axis0, int64_t axis1) {
return self.transpose(axis0, axis1);
}
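
Note: both overloads simply forward to `at::movedim`, which is implemented via `permute`, so the result is a view sharing storage with `self` (hence the `Tensor(a)` annotations in the schema below). A quick illustration, not part of the patch:

import torch

x = torch.zeros(2, 3)
y = torch.moveaxis(x, 0, 1)    # forwards to movedim -> permute: y is a view of x
y[0, 0] = 42.0
assert x[0, 0].item() == 42.0  # writes through the view are visible in x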
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3cc6f3a..8f237f9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3057,6 +3057,15 @@
use_c10_dispatcher: full
variants: function, method
+# moveaxis, alias for movedim
+- func: moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+ use_c10_dispatcher: full
+ variants: function, method
+
+- func: moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+ use_c10_dispatcher: full
+ variants: function, method
+
# Only exposed from C++ -- in Python,
# we expose it as an attribute `T`, not a function.
#
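
Note: `variants: function, method` exposes both spellings, and the `.int`/`.intlist` overloads resolve on argument types. Sketch:

import torch

t = torch.randn(4, 3, 2)
assert torch.equal(torch.moveaxis(t, 0, -1), t.moveaxis(0, -1))  # function vs. method
assert t.moveaxis(0, 1).shape == t.moveaxis((0,), (1,)).shape    # int vs. int[] overload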
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index d7b0af7..3f12004 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -461,6 +461,7 @@
:noindex:
.. automethod:: mode
.. automethod:: movedim
+ .. automethod:: moveaxis
.. automethod:: mul
.. automethod:: mul_
.. automethod:: multiply
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index b16f14d..d7c80de 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -92,6 +92,7 @@
index_select
masked_select
movedim
+ moveaxis
narrow
nonzero
reshape
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 5c8acd7..125fa7a 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7240,15 +7240,16 @@
self.assertEqual(c.grad.stride(), (2, 1))
def test_movedim(self, device):
- x = torch.randn(4, 3, 2, 1, dtype=torch.double, device=device, requires_grad=True)
+ for fn in [torch.movedim, torch.moveaxis]:
+ x = torch.randn(4, 3, 2, 1, dtype=torch.double, device=device, requires_grad=True)
- # Positive axis
- gradcheck(lambda x: torch.movedim(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
- gradgradcheck(lambda x: torch.movedim(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
+ # Positive axis
+ gradcheck(lambda x: fn(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
+ gradgradcheck(lambda x: fn(x, (0, 1, 2, 3), (3, 2, 1, 0)), x)
- # Negative axis
- gradcheck(lambda x: torch.movedim(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
- gradgradcheck(lambda x: torch.movedim(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
+ # Negative axis
+ gradcheck(lambda x: fn(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
+ gradgradcheck(lambda x: fn(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x)
def _test_atleast(self, device, torch_fn):
# 0-dim
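
A standalone version of the check the updated test runs for each `fn` (a sketch; `gradcheck` requires double precision):

import torch
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(4, 3, 2, 1, dtype=torch.double, requires_grad=True)
# moveaxis only permutes dimensions, so its backward is the inverse permutation;
# gradcheck/gradgradcheck verify this against numerical gradients.
assert gradcheck(lambda t: torch.moveaxis(t, (0, 1, 2, 3), (3, 2, 1, 0)), (x,))
assert gradgradcheck(lambda t: torch.moveaxis(t, (0, 1, 2, 3), (3, 2, 1, 0)), (x,))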
diff --git a/test/test_op_aliases.py b/test/test_op_aliases.py
index 7ba2f48..8a9b7c3 100644
--- a/test/test_op_aliases.py
+++ b/test/test_op_aliases.py
@@ -161,6 +161,8 @@
lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
AliasInfo('row_stack', torch.row_stack, 'vstack', torch.vstack,
lambda d: ((torch.randn(20, device=d), torch.randn(20, device=d)))),
+ AliasInfo('moveaxis', torch.moveaxis, 'movedim', torch.movedim,
+ lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
)
# Placeholder test class for validating that aliases are correctly
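
In spirit, the new `AliasInfo` entry asserts the following (a simplified sketch of what the alias harness checks):

import torch

t = torch.randn(20, 3, 2, 1)
assert torch.equal(torch.moveaxis(t, 3, 1), torch.movedim(t, 3, 1))  # function form
assert torch.equal(t.moveaxis(3, 1), t.movedim(3, 1))                # method form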
diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py
index 3a7afa7..4332150 100644
--- a/test/test_shape_ops.py
+++ b/test/test_shape_ops.py
@@ -86,95 +86,97 @@
shape = self._rand_shape(4, min_size=5, max_size=10)
x = _generate_input(shape, dtype, device, False)
- # Invalid `source` and `destination` dimension
- with self.assertRaisesRegex(IndexError, "Dimension out of range"):
- torch.movedim(x, 5, 0)
+ for fn in [torch.movedim, torch.moveaxis]:
+ # Invalid `source` and `destination` dimension
+ with self.assertRaisesRegex(IndexError, "Dimension out of range"):
+ fn(x, 5, 0)
- with self.assertRaisesRegex(IndexError, "Dimension out of range"):
- torch.movedim(x, 0, 5)
+ with self.assertRaisesRegex(IndexError, "Dimension out of range"):
+ fn(x, 0, 5)
- # Mismatch in size of `source` and `destination`
- with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"):
- torch.movedim(x, (1, 0), (0, ))
+ # Mismatch in size of `source` and `destination`
+ with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"):
+ fn(x, (1, 0), (0, ))
- with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
- torch.movedim(x, (0, 0), (0, 1))
+ with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
+ fn(x, (0, 0), (0, 1))
- with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
- torch.movedim(x, (0, 1, 0), (0, 1, 2))
+ with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
+ fn(x, (0, 1, 0), (0, 1, 2))
- with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
- torch.movedim(x, (0, 1), (1, 1))
+ with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
+ fn(x, (0, 1), (1, 1))
- with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
- torch.movedim(x, (0, 1, 2), (1, 0, 1))
+ with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
+ fn(x, (0, 1, 2), (1, 0, 1))
@dtypes(torch.int64, torch.float, torch.complex128)
def test_movedim(self, device, dtype):
- for nd in range(5):
- shape = self._rand_shape(nd, min_size=5, max_size=10)
- x = _generate_input(shape, dtype, device, with_extremal=False)
- for random_negative in [True, False]:
- for src_dim, dst_dim in permutations(range(nd), r=2):
- random_prob = random.random()
+ for fn in [torch.moveaxis, torch.movedim]:
+ for nd in range(5):
+ shape = self._rand_shape(nd, min_size=5, max_size=10)
+ x = _generate_input(shape, dtype, device, with_extremal=False)
+ for random_negative in [True, False]:
+ for src_dim, dst_dim in permutations(range(nd), r=2):
+ random_prob = random.random()
- if random_negative and random_prob > 0.66:
- src_dim = src_dim - nd
- elif random_negative and random_prob > 0.33:
- dst_dim = dst_dim - nd
- elif random_negative:
- src_dim = src_dim - nd
- dst_dim = dst_dim - nd
+ if random_negative and random_prob > 0.66:
+ src_dim = src_dim - nd
+ elif random_negative and random_prob > 0.33:
+ dst_dim = dst_dim - nd
+ elif random_negative:
+ src_dim = src_dim - nd
+ dst_dim = dst_dim - nd
- # Integer `source` and `destination`
- torch_fn = partial(torch.movedim, source=src_dim, destination=dst_dim)
- np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim)
- self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+ # Integer `source` and `destination`
+ torch_fn = partial(fn, source=src_dim, destination=dst_dim)
+ np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim)
+ self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
- if nd == 0:
- continue
+ if nd == 0:
+ continue
- def make_index_negative(sequence, idx):
- sequence = list(sequence)
- sequence[random_idx] = sequence[random_idx] - nd
- return tuple(src_sequence)
+                    def make_index_negative(sequence, idx):
+                        sequence = list(sequence)
+                        sequence[idx] = sequence[idx] - nd
+                        return tuple(sequence)
- for src_sequence in permutations(range(nd), r=random.randint(1, nd)):
- # Sequence `source` and `destination`
- dst_sequence = tuple(random.sample(range(nd), len(src_sequence)))
+ for src_sequence in permutations(range(nd), r=random.randint(1, nd)):
+ # Sequence `source` and `destination`
+ dst_sequence = tuple(random.sample(range(nd), len(src_sequence)))
- # Randomly change a dim to a negative dim representation of itself.
- random_prob = random.random()
- if random_negative and random_prob > 0.66:
- random_idx = random.randint(0, len(src_sequence) - 1)
- src_sequence = make_index_negative(src_sequence, random_idx)
- elif random_negative and random_prob > 0.33:
- random_idx = random.randint(0, len(src_sequence) - 1)
- dst_sequence = make_index_negative(dst_sequence, random_idx)
- elif random_negative:
- random_idx = random.randint(0, len(src_sequence) - 1)
- dst_sequence = make_index_negative(dst_sequence, random_idx)
- random_idx = random.randint(0, len(src_sequence) - 1)
- src_sequence = make_index_negative(src_sequence, random_idx)
+ # Randomly change a dim to a negative dim representation of itself.
+ random_prob = random.random()
+ if random_negative and random_prob > 0.66:
+ random_idx = random.randint(0, len(src_sequence) - 1)
+ src_sequence = make_index_negative(src_sequence, random_idx)
+ elif random_negative and random_prob > 0.33:
+ random_idx = random.randint(0, len(src_sequence) - 1)
+ dst_sequence = make_index_negative(dst_sequence, random_idx)
+ elif random_negative:
+ random_idx = random.randint(0, len(src_sequence) - 1)
+ dst_sequence = make_index_negative(dst_sequence, random_idx)
+ random_idx = random.randint(0, len(src_sequence) - 1)
+ src_sequence = make_index_negative(src_sequence, random_idx)
- torch_fn = partial(torch.movedim, source=src_sequence, destination=dst_sequence)
- np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence)
- self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+ torch_fn = partial(fn, source=src_sequence, destination=dst_sequence)
+ np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence)
+ self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
- # Move dim to same position
- x = torch.randn(2, 3, 5, 7, 11)
- torch_fn = partial(torch.movedim, source=(0, 1), destination=(0, 1))
- np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
- self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+ # Move dim to same position
+ x = torch.randn(2, 3, 5, 7, 11)
+ torch_fn = partial(fn, source=(0, 1), destination=(0, 1))
+ np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
+ self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
- torch_fn = partial(torch.movedim, source=1, destination=1)
- np_fn = partial(np.moveaxis, source=1, destination=1)
- self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+ torch_fn = partial(fn, source=1, destination=1)
+ np_fn = partial(np.moveaxis, source=1, destination=1)
+ self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
- # Empty Sequence
- torch_fn = partial(torch.movedim, source=(), destination=())
- np_fn = partial(np.moveaxis, source=(), destination=())
- self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
+ # Empty Sequence
+ torch_fn = partial(fn, source=(), destination=())
+ np_fn = partial(np.moveaxis, source=(), destination=())
+ self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
@dtypes(torch.float, torch.bool)
def test_diag(self, device, dtype):
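
Condensed illustration of the error contract exercised above; the messages intentionally say "movedim" even when called via the alias, since `moveaxis` forwards to it:

import torch

x = torch.randn(5, 5, 5, 5)
try:
    torch.moveaxis(x, 5, 0)          # out-of-range dim
except IndexError as e:
    assert "Dimension out of range" in str(e)
try:
    torch.moveaxis(x, (1, 0), (0,))  # source/destination length mismatch
except RuntimeError as e:
    assert "movedim: Invalid source or destination dims" in str(e)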
diff --git a/test/test_view_ops.py b/test/test_view_ops.py
index d4e59a3..15f1bcd 100644
--- a/test/test_view_ops.py
+++ b/test/test_view_ops.py
@@ -535,11 +535,12 @@
out[idx_1, idx_2] = random.random()
self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])
- op = partial(torch.movedim, source=(0, 1), destination=(1, 0))
- run_test(device, op)
+ for fn in [torch.movedim, torch.moveaxis]:
+ op = partial(fn, source=(0, 1), destination=(1, 0))
+ run_test(device, op)
- op = partial(torch.movedim, source=0, destination=1)
- run_test(device, op)
+ op = partial(fn, source=0, destination=1)
+ run_test(device, op)
class TestOldViewOps(TestCase):
def test_ravel(self, device):
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 99f504a..87bbf38 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2403,6 +2403,12 @@
See :func:`torch.movedim`
""")
+add_docstr_all('moveaxis', r"""
+moveaxis(source, destination) -> Tensor
+
+See :func:`torch.moveaxis`
+""")
+
add_docstr_all('mul', r"""
mul(value) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index dd4be74..7852f3b 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5749,6 +5749,43 @@
[[-0.8437, 0.1727, -0.1398]]])
""".format(**common_args))
+add_docstr(torch.moveaxis, r"""
+moveaxis(input, source, destination) -> Tensor
+
+Alias for :func:`torch.movedim`.
+
+This function is equivalent to NumPy's ``moveaxis`` function.
+
+Examples::
+
+    >>> t = torch.randn(3, 2, 1)
+ >>> t
+ tensor([[[-0.3362],
+ [-0.8437]],
+
+ [[-0.9627],
+ [ 0.1727]],
+
+ [[ 0.5173],
+ [-0.1398]]])
+ >>> torch.moveaxis(t, 1, 0).shape
+ torch.Size([2, 3, 1])
+ >>> torch.moveaxis(t, 1, 0)
+ tensor([[[-0.3362],
+ [-0.9627],
+ [ 0.5173]],
+
+ [[-0.8437],
+ [ 0.1727],
+ [-0.1398]]])
+ >>> torch.moveaxis(t, (1, 2), (0, 1)).shape
+ torch.Size([2, 1, 3])
+ >>> torch.moveaxis(t, (1, 2), (0, 1))
+ tensor([[[-0.3362, -0.9627, 0.5173]],
+
+ [[-0.8437, 0.1727, -0.1398]]])
+""".format(**common_args))
+
add_docstr(torch.swapdims, r"""
swapdims(input, dim0, dim1) -> Tensor
diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp
index 2f463f7..e032eeb 100644
--- a/torch/csrc/jit/passes/normalize_ops.cpp
+++ b/torch/csrc/jit/passes/normalize_ops.cpp
@@ -103,6 +103,7 @@
{aten::swapdims_, aten::transpose_},
{aten::swapaxes, aten::transpose},
{aten::swapaxes_, aten::transpose_},
+ {aten::moveaxis, aten::movedim},
};
return alias_map;
}
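
The alias-map entry lets the normalization pass canonicalize `aten::moveaxis` to `aten::movedim`, so downstream passes only need to handle one op. A hedged way to observe this (exactly when the pass runs can vary by build):

import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.moveaxis(x, 0, -1)

f(torch.randn(2, 3))  # scripting the alias works; after normalization the
print(f.graph)        # graph should refer to aten::movedim, not aten::moveaxis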
diff --git a/torch/overrides.py b/torch/overrides.py
index f6a4937..f7b9bed 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -522,6 +522,7 @@
torch.mm: lambda input, mat2, out=None: -1,
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
torch.movedim: lambda input, source, destination: -1,
+ torch.moveaxis: lambda input, source, destination: -1,
torch.mul: lambda input, other, out=None: -1,
torch.multiply: lambda input, other, out=None: -1,
torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
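
The `torch.overrides` entry registers `moveaxis` with the `__torch_function__` machinery; the lambda records the reference signature `(input, source, destination)` used by the override tests. A quick check via the public helper:

import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()
assert torch.moveaxis in overrides  # the lambda added above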