[NJT] Add permute op support (#135336)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135336
Approved by: https://github.com/davidberard98
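
This adds aten.permute support for jagged-layout NJTs: permuting the batch
dim (or an NJT with holes) raises, and any other permutation is forwarded to
the underlying values tensor with the ragged index remapped to its new
position. A minimal usage sketch of the new behavior (the constructed tensors
and sizes are illustrative and mirror the 4-D jagged case in the test below):

    import torch

    # Jagged-layout NJT of logical shape (B=2, j1, 3, 5); dim 1 is the ragged dim.
    nt = torch.nested.nested_tensor(
        [torch.randn(2, 3, 5), torch.randn(4, 3, 5)], layout=torch.jagged
    )

    # Non-batch dims can be permuted; the ragged index follows its dim.
    nt_p = nt.permute(0, 2, 1, -1)  # logical shape (B, 3, j1, 5); nt_p._ragged_idx == 2

    # Moving the batch dim is rejected: nt.permute(1, 0, 2, 3) raises ValueError.
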
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index d9fcd7c..49cb422 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -5955,6 +5955,37 @@
         self.assertFalse(clone.is_contiguous())
         check_nt_equality(detached, transposed)
 
+    def test_permute(self, device):
+        nt = random_nt_from_dims(
+            [2, None, 3, 5], device, torch.float32, layout=torch.jagged
+        )
+        nt_shape = nt.shape
+        nt_inner_shape = nt.values().shape
+        with self.assertRaisesRegex(
+            ValueError,
+            r"permute\(\): number of dimensions in the tensor input \(4\) "
+            r"does not match the length of the desired ordering of dimensions \(3\).",
+        ):
+            nt.permute(0, 2, 1)
+        with self.assertRaisesRegex(
+            ValueError, r"permute\(\): duplicate dims are not allowed."
+        ):
+            nt.permute(0, 2, -2, 3)
+        with self.assertRaisesRegex(
+            ValueError, "Permute is not supported on the batch dimension for jagged NT"
+        ):
+            nt.permute(1, 0, 2, 3)
+        nt_permute = nt.permute(0, 2, 1, -1)
+        self.assertEqual(
+            nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3])
+        )
+        self.assertEqual(
+            nt_permute.values().shape,
+            (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]),
+        )
+        self.assertEqual(nt_permute._ragged_idx, 2)
+        self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt)
+
     def test_to_dtype(self, device):
         nt = random_nt_from_dims(
             [2, None, 3], device, torch.float32, layout=torch.jagged
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 0d20327..ed9a54f 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -1159,6 +1159,46 @@
     return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
 
 
+@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any")
+def permute_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    dims = new_kwargs.pop("dims")
+    inp_kwargs = extract_kwargs(inp)
+    inp_dim = len(inp._size)
+
+    # The first two checks are the same as the checks in the normal permute implementation
+    if inp_dim != len(dims):
+        raise ValueError(
+            f"permute(): number of dimensions in the tensor input ({inp_dim}) "
+            f"does not match the length of the desired ordering of dimensions ({len(dims)}).",
+        )
+
+    from torch._prims_common import canonicalize_dims
+
+    canonicalized_dims = canonicalize_dims(inp_dim, dims)
+
+    if len(canonicalized_dims) != len(set(canonicalized_dims)):
+        raise ValueError("permute(): duplicate dims are not allowed.")
+
+    if inp._lengths is not None:
+        raise ValueError(
+            "permute(): not supported on jagged layout nested tensor with holes"
+        )
+    if canonicalized_dims[0] != 0:
+        raise ValueError(
+            "Permute is not supported on the batch dimension for jagged NT"
+        )
+ inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
+ inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]]
+ new_kwargs["dims"] = inner_dims
+ return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
+
+
 @register_jagged_func(
     [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
     "self: jt_all, size: any",