[NJT] Add permute ops support (#135336)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135336
Approved by: https://github.com/davidberard98
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index d9fcd7c..49cb422 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -5955,6 +5955,37 @@
         self.assertFalse(clone.is_contiguous())
         check_nt_equality(detached, transposed)
 
+    def test_permute(self, device):
+        nt = random_nt_from_dims(
+            [2, None, 3, 5], device, torch.float32, layout=torch.jagged
+        )
+        nt_shape = nt.shape
+        nt_inner_shape = nt.values().shape
+        with self.assertRaisesRegex(
+            ValueError,
+            r"permute\(\): number of dimensions in the tensor input \(4\) "
+            + r"does not match the length of the desired ordering of dimensions \(3\).",
+        ):
+            nt.permute(0, 2, 1)
+        with self.assertRaisesRegex(
+            ValueError, r"permute\(\): duplicate dims are not allowed."
+        ):
+            nt.permute(0, 2, -2, 3)
+        with self.assertRaisesRegex(
+            ValueError, "Permute is not supported on the batch dimension for jagged NT"
+        ):
+            nt.permute(1, 0, 2, 3)
+        nt_permute = nt.permute(0, 2, 1, -1)
+        self.assertEqual(
+            nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3])
+        )
+        self.assertEqual(
+            nt_permute.values().shape,
+            (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]),
+        )
+        self.assertEqual(nt_permute._ragged_idx, 2)
+        self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt)
+
     def test_to_dtype(self, device):
         nt = random_nt_from_dims(
             [2, None, 3], device, torch.float32, layout=torch.jagged
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 0d20327..ed9a54f 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -1159,6 +1159,44 @@
     return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
 
 
+@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any")
+def permute_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    dims = new_kwargs.pop("dims")
+    inp_kwargs = extract_kwargs(inp)
+    inp_dim = len(inp._size)
+
+    # The first two checks are the same as the checks in the normal permute implementation
+    if inp_dim != len(dims):
+        raise ValueError(
+            f"permute(): number of dimensions in the tensor input ({inp_dim}) "
+            + f"does not match the length of the desired ordering of dimensions ({len(dims)}).",
+        )
+
+    from torch._prims_common import canonicalize_dims
+
+    canonicalized_dims = canonicalize_dims(inp_dim, dims)
+
+    if len(canonicalized_dims) != len(set(canonicalized_dims)):
+        raise ValueError("permute(): duplicate dims are not allowed.")
+
+    if inp._lengths is not None:
+        raise ValueError(
+            "permute(): not supported on jagged layout nested tensor with holes"
+        )
+    if canonicalized_dims[0] != 0:
+        raise ValueError(
+            "Permute is not supported on the batch dimension for jagged NT"
+        )
+    inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
+    inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]]
+    new_kwargs["dims"] = inner_dims
+    return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
+
+
 @register_jagged_func(
     [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
     "self: jt_all, size: any",