Fix SDPA for SAM (#115636)

Addresses the regression for Segment Anything Fast in https://github.com/pytorch-labs/segment-anything-fast/issues/99
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115636
Approved by: https://github.com/soulitzer, https://github.com/ani300
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 4684c90..ac215b6 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -9,6 +9,7 @@
 import numpy as np
 import torch
 import torch.nn
+import torch.nn.functional as F
 from torch.testing._internal.common_cuda import SM80OrLater
 from torch.testing._internal.common_device_type import (
     dtypes,
@@ -133,14 +134,14 @@
 # Alternate approach to generating a random NT.
 # dims should be something like [5, None, 10], with None indicating that a
 # random ragged structure should be used
-def random_nt_from_dims(dims, device=None, dtype=None, requires_grad=False):
+def random_nt_from_dims(dims, device=None, dtype=None, layout=torch.strided, requires_grad=False):
     sizes = [
         [d if d is not None else torch.randint(2, 10, size=(1,)).item() for d in dims[1:]]
         for d in range(dims[0])
     ]
     return torch.nested.nested_tensor([
         torch.randn(*size) for size in sizes
-    ], device=device, dtype=dtype, requires_grad=requires_grad)
+    ], device=device, dtype=dtype, layout=layout, requires_grad=requires_grad)
 
 
 # Creates an NT matching another NT's number of components and
@@ -3610,6 +3611,23 @@
         self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
         self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
 
+    @dtypes(torch.float32, torch.double, torch.half)
+    def test_sdpa_with_constant_sequence_length(self, device, dtype):
+        # shape (B, P*, S, D)
+        # B: batch size
+        # P*: ragged number of prompts
+        # S: (constant) sequence length
+        # D: embedding size
+        query = random_nt_from_dims(
+            [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged)
+        key = random_nt_from_similar(query)
+        value = random_nt_from_similar(query)
+        output = F.scaled_dot_product_attention(query, key, value)
+        self.assertTrue(isinstance(output, NestedTensor))
+
+        # should be equivalent to just running the buffers through
+        output_dense = F.scaled_dot_product_attention(query._values, key._values, value._values)
+        self.assertEqual(output._values, output_dense)
 
 
 instantiate_parametrized_tests(TestNestedTensor)
diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py
index f4d7d67..4dc28e5 100644
--- a/torch/nested/_internal/sdpa.py
+++ b/torch/nested/_internal/sdpa.py
@@ -4,6 +4,7 @@
 
 import torch
 import torch.nn
+import torch.nn.functional as F
 from torch.backends.cuda import (
     can_use_efficient_attention,
     can_use_flash_attention,
@@ -55,10 +56,10 @@
             f"Expected query, key, and value to all be  at least 2 dimensional, but got query.dim: "
             f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
         )
-    if query._ragged_idx != 2 or key._ragged_idx != 2 or value._ragged_idx != 2:
+    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
         raise ValueError(
-            f"Expected query, key, and value to all be be jagged at dimension 2, but got query._ragged_idx: "
-            f"{query._ragged_idx}, key._ragged_idx: {key._ragged_idx} and value._ragged_idx: {value._ragged_idx} instead."
+            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
+            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
         )
     if attn_mask is not None:
         # TODO: Figure out whether masks are actually supported for this layout or not
@@ -622,6 +623,33 @@
     scale=None,
 ):
     _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
+    # for mypy, ugh
+    assert (
+        isinstance(query, NestedTensor)
+        and isinstance(key, NestedTensor)
+        and isinstance(value, NestedTensor)
+    )
+
+    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
+    # second batch dim instead). For this case, we can just send the dense buffers through
+    # vanilla SDPA.
+    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
+        from torch.nested._internal.ops import extract_kwargs
+
+        output = F.scaled_dot_product_attention(
+            query._values,
+            key._values,
+            value._values,
+            attn_mask=(
+                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
+            ),
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+        )
+
+        return NestedTensor(output, **extract_kwargs(query))
+
     compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
 
     backend_choice = _select_sdp_backend(