Fix SDPA for SAM (#115636)
Addresses the regression for Segment Anything Fast in https://github.com/pytorch-labs/segment-anything-fast/issues/99
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115636
Approved by: https://github.com/soulitzer, https://github.com/ani300
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 4684c90..ac215b6 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -9,6 +9,7 @@
import numpy as np
import torch
import torch.nn
+import torch.nn.functional as F
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import (
dtypes,
@@ -133,14 +134,14 @@
# Alternate approach to generating a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
-def random_nt_from_dims(dims, device=None, dtype=None, requires_grad=False):
+def random_nt_from_dims(dims, device=None, dtype=None, layout=torch.strided, requires_grad=False):
sizes = [
[d if d is not None else torch.randint(2, 10, size=(1,)).item() for d in dims[1:]]
for d in range(dims[0])
]
return torch.nested.nested_tensor([
torch.randn(*size) for size in sizes
- ], device=device, dtype=dtype, requires_grad=requires_grad)
+ ], device=device, dtype=dtype, layout=layout, requires_grad=requires_grad)
# Creates an NT matching another NT's number of components and
@@ -3610,6 +3611,23 @@
self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
+ @dtypes(torch.float32, torch.double, torch.half)
+ def test_sdpa_with_constant_sequence_length(self, device, dtype):
+ # shape (B, P*, S, D)
+ # B: batch size
+ # P*: ragged number of prompts
+ # S: (constant) sequence length
+ # D: embedding size
+ query = random_nt_from_dims(
+ [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged)
+ key = random_nt_from_similar(query)
+ value = random_nt_from_similar(query)
+ output = F.scaled_dot_product_attention(query, key, value)
+ self.assertTrue(isinstance(output, NestedTensor))
+
+ # should be equivalent to just running the buffers through
+ output_dense = F.scaled_dot_product_attention(query._values, key._values, value._values)
+ self.assertEqual(output._values, output_dense)
instantiate_parametrized_tests(TestNestedTensor)
diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py
index f4d7d67..4dc28e5 100644
--- a/torch/nested/_internal/sdpa.py
+++ b/torch/nested/_internal/sdpa.py
@@ -4,6 +4,7 @@
import torch
import torch.nn
+import torch.nn.functional as F
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
@@ -55,10 +56,10 @@
f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
)
- if query._ragged_idx != 2 or key._ragged_idx != 2 or value._ragged_idx != 2:
+ if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
raise ValueError(
- f"Expected query, key, and value to all be be jagged at dimension 2, but got query._ragged_idx: "
- f"{query._ragged_idx}, key._ragged_idx: {key._ragged_idx} and value._ragged_idx: {value._ragged_idx} instead."
+ f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
+ f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
)
if attn_mask is not None:
# TODO: Figure out whether masks are actually supported for this layout or not
@@ -622,6 +623,33 @@
scale=None,
):
_validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
+ # for mypy, ugh
+ assert (
+ isinstance(query, NestedTensor)
+ and isinstance(key, NestedTensor)
+ and isinstance(value, NestedTensor)
+ )
+
+ # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
+ # second batch dim instead). For this case, we can just send the dense buffers through
+ # vanilla SDPA.
+ if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
+ from torch.nested._internal.ops import extract_kwargs
+
+ output = F.scaled_dot_product_attention(
+ query._values,
+ key._values,
+ value._values,
+ attn_mask=(
+ attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
+ ),
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ )
+
+ return NestedTensor(output, **extract_kwargs(query))
+
compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
backend_choice = _select_sdp_backend(