[NestedTensor] Support ragged_idx != 1 in pointwise ops (#118157)

This PR allows pointwise ops to operate on NestedTensors with ragged_idx != 1. It does this by passing the ragged_idx metadata into the construction of the returned NestedTensor when computing pointwise ops. The assumption is that pointwise ops can operate directly on the values tensor and that the resulting tensor should carry the same metadata properties as the input tensors. For binary ops, a test is added to verify that two tensors with different ragged_idx cannot be combined.
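
For example, a rough sketch of the binary behavior using `jagged_from_list` (the internal helper the new tests use to build jagged-layout NestedTensors); sizes here are illustrative:

```python
import torch
from torch.nested._internal.nested_tensor import jagged_from_list

a, b, c = (torch.randn(i + 2, 5) for i in range(3))
nt1, offsets = jagged_from_list([a, b, c], None)
nt2, _ = jagged_from_list([a, b, c], offsets)   # reuse offsets so the ragged sizes match

nt1_t, nt2_t = nt1.transpose(1, 2), nt2.transpose(1, 2)
out = nt1_t * nt2_t   # OK: both operands have _ragged_idx == 2

try:
    nt1 * nt2_t       # mixing ragged_idx 1 and 2
except RuntimeError as e:
    # "cannot call binary pointwise function mul.Tensor with inputs of shapes ..."
    print(e)
```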

Previously:
* unary pointwise ops would error out when performed on nested tensors with ragged_idx != 1 (this case is sketched below)
* binary pointwise ops would produce tensors with nonsense shapes
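
A minimal sketch of the unary case that used to error (again using the internal `jagged_from_list` helper; sizes are illustrative):

```python
import torch
from torch.nested._internal.nested_tensor import jagged_from_list

a, b, c = (torch.randn(i + 2, 5) for i in range(3))
nt, _ = jagged_from_list([a, b, c], None)

nt_t = nt.transpose(1, 2)          # ragged dim moves to dim 2, so _ragged_idx != 1
assert not nt_t.is_contiguous()

out = torch.nn.functional.silu(nt_t.sin().cos())  # previously raised, now works
assert out.shape == nt_t.shape                    # result keeps the input's metadata
```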

Differential Revision: [D53032641](https://our.internmc.facebook.com/intern/diff/D53032641)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118157
Approved by: https://github.com/jbschlosser
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index fcdf6fe..942185f 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -3054,6 +3054,32 @@
 
         gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
 
+    def test_unary_pointwise_transposed_inputs(self, device):
+        a, b, c = (
+            torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)
+        )
+
+        nt, _ = jagged_from_list([a.detach(), b.detach(), c.detach()], None)
+        nt_t = nt.transpose(1, 2)
+        self.assertFalse(nt_t.is_contiguous())
+        out = torch.nn.functional.silu(nt_t.sin().cos())
+        self.assertEqual(out.is_contiguous(), torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous())
+
+        self.assertEqual(nt_t.shape, out.shape)
+
+        a, b, c = (
+            torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)
+        )
+
+        def grad_test_func(a, b, c):
+            nt, _ = jagged_from_list([a, b, c], None)
+            nt_t = nt.transpose(1, 2)
+            out = torch.nn.functional.silu(nt_t.sin().cos())
+            return buffer_from_jagged(out)
+
+        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
+
+
     def test_binary_pointwise(self, device):
         a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
         b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
@@ -3078,6 +3104,43 @@
 
         gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
 
+    def test_binary_pointwise_transposed(self, device):
+        a, b, c = (
+            torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3)
+        )
+
+        nt1, offsets = jagged_from_list([a, b, c], None)
+        nt2, offsets = jagged_from_list([a, b, c], offsets)
+
+        nt1_t = nt1.transpose(1, 2)
+        nt2_t = nt2.transpose(1, 2)
+
+        out = nt1_t * nt2_t
+        self.assertFalse(nt1_t.is_contiguous())
+        self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous())
+        self.assertEqual(out.shape, nt1_t.shape)
+
+        self.assertRaisesRegex(
+            RuntimeError,
+            "cannot call binary pointwise function mul.Tensor with inputs of shapes",
+            lambda: nt1 * nt2_t,
+        )
+
+        a, b, c = (
+            torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)
+        )
+
+        # Correct usage: chain the calls using the same offsets tensor object
+        def grad_test_func(a, b, c):
+            nt1, offsets = jagged_from_list([a, b, c], None)
+            nt2, offsets = jagged_from_list([a, b, c], offsets)
+            nt1_t = nt1.transpose(1, 2)
+            nt2_t = nt2.transpose(1, 2)
+            out = nt1_t * nt2_t
+            return buffer_from_jagged(out)
+
+        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
+
     def test_split(self, device):
         a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
         b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
@@ -3555,9 +3618,25 @@
         nt, _ = jagged_from_list([a, b, c], None)
         # transpose ragged dim
         transposed = nt.transpose(1, 2)
-        # pointwise ops are not supported on ragged dim transposed jagged layout NTs
-        with self.assertRaisesRegex(ValueError, "expected .* to be a contiguous jagged layout"):
-            clone = transposed.clone()
+        self.assertFalse(transposed.is_contiguous())
+        clone = transposed.clone()
+
+        def check_nt_equality(x, y):
+            self.assertEqual(x.values(), y.values())
+            self.assertEqual(x.offsets(), y.offsets())
+            self.assertEqual(x._ragged_idx, y._ragged_idx)
+            self.assertEqual(x.shape, y.shape)
+
+        self.assertFalse(clone.is_contiguous())
+        check_nt_equality(clone, transposed)
+
+        clone_contig = transposed.clone(memory_format=torch.contiguous_format)
+        self.assertTrue(clone_contig.is_contiguous())
+        check_nt_equality(clone_contig, transposed)
+
+        detached = transposed.detach()
+        self.assertFalse(detached.is_contiguous())
+        check_nt_equality(detached, transposed)
 
     # Note 1: Math fallback doesn't work with bfloat16 on CUDA
     # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
@@ -3679,7 +3758,6 @@
             if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
                 check_forward_backward()
 
-
     # This requires NT -> NT views to work in inductor, which is a TODO
     @unittest.expectedFailure  # noqa: E301
     @onlyCUDA
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 3aacd3e..8d68b68 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -191,7 +191,7 @@
         # Assume there aren't additional tensors that aren't the "unary/binary" args
         num_tensor_args = sum([isinstance(x, torch.Tensor) for x in args])
         if num_tensor_args == 1:
-            check_schema("self: jt, ...", func, *args, **kwargs)
+            check_schema("self: jt_all, ...", func, *args, **kwargs)
             return functools.partial(jagged_unary_pointwise, func)
         elif num_tensor_args == 2:
             check_schema("lhs: any, rhs: any", func, *args, **kwargs)
@@ -204,6 +204,7 @@
     kwargs = {
         "offsets": arg.offsets(),
         "_metadata_cache": arg._metadata_cache,
+        "_ragged_idx": arg._ragged_idx,
     }
     return kwargs
 
@@ -374,10 +375,6 @@
     if inp.lengths() is not None:
         return False
 
-    # If jagged dim is not 1 it's not contiguous
-    if inp._ragged_idx != 1:
-        return False
-
     new_kwargs["memory_format"] = new_kwargs.get(
         "memory_format", torch.contiguous_format
     )
@@ -446,7 +443,7 @@
         torch.ops.aten.randn_like.default,
         torch.ops.aten.detach.default,
     ],
-    "self: jt",
+    "self: jt_all",
 )(jagged_unary_pointwise)
 
 
@@ -791,13 +788,14 @@
             to_dim = dim1
         else:
             to_dim = dim0
+        inp_kwargs = extract_kwargs(inp)
+        inp_kwargs["_ragged_idx"] = to_dim
         return NestedTensor(
             inp.values().transpose(
                 _outer_to_inner_dim(len(inp._size), dim0),
                 _outer_to_inner_dim(len(inp._size), dim1),
             ),
-            **extract_kwargs(inp),
-            _ragged_idx=to_dim,
+            **inp_kwargs,
         )
 
     new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")