make flash_attn_bw impl correct w.r.t. meta when k and v have different strides (#119500)
`dv = at::empty_like(k)` and `dv = at::empty_like(v)` can produce materially different tensors, because `empty_like` tries to preserve the strides of its input when possible. So if `k` is contiguous but `v` is transposed, then before this PR the kernel would allocate `dv` contiguous, which does not match what the meta implementation reports.
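For intuition, here is a small standalone sketch (not part of this PR; shapes borrowed from the repro below) showing how `torch.empty_like`, which behaves like `at::empty_like` here, propagates the strides of whichever tensor it is given:
```
import torch

# contiguous key, laid out as (batch, heads, seq, head_dim)
k = torch.randn(2, 16, 513, 64)
# value stored as (batch, seq, heads, head_dim) and viewed back as (batch, heads, seq, head_dim)
v = torch.randn(2, 513, 16, 64).transpose(1, 2)

# empty_like preserves the input's strides when the input is non-overlapping and dense
dv_from_k = torch.empty_like(k)
dv_from_v = torch.empty_like(v)

print(dv_from_k.stride(), dv_from_k.is_contiguous())  # (525312, 32832, 64, 1) True
print(dv_from_v.stride(), dv_from_v.is_contiguous())  # (525312, 64, 1024, 1) False
```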
Alternatively, we could change the meta implementation of `aten._scaled_dot_product_flash_attention_backward` to this:
```
grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
grad_v = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)  # key (not value), mirroring the kernel's current empty_like(k)
return grad_q, grad_k, grad_v
```
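As a quick illustration of that alternative (again just a sketch, not code from this PR; `grad_v_from_key` and `grad_v_from_value` are local names for illustration), the transpose/`empty_like`/transpose pattern based on the contiguous `key` yields a contiguous gradient, whereas basing it on a transposed `value` does not:
```
import torch

key = torch.randn(2, 16, 513, 64)                    # contiguous
value = torch.randn(2, 513, 16, 64).transpose(1, 2)  # transposed view

# grad_v as in the alternative meta formula above (based on key)
grad_v_from_key = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
# variant based on value instead
grad_v_from_value = torch.empty_like(value.transpose(1, 2)).transpose(1, 2)

print(grad_v_from_key.is_contiguous())    # True
print(grad_v_from_value.is_contiguous())  # False
```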
But I think the `empty_like(k)` in the sdpa backward impl was just a typo, so this PR fixes the kernel instead.
I noticed this because changing the meta formula as above was enough to fix the `aot_eager` failure reported in [this comment](https://github.com/pytorch/pytorch/issues/116935#issuecomment-1914310523).
A minimal repro that I made looks like this:
```
import torch
# in this repro, "grad_out" and "value" are transposed tensors,
# but "query" and "key" are contiguous
a = torch.randn(2, 513, 16, 64, dtype=torch.float16, device='cuda').transpose(1, 2)  # grad_out (transposed)
b = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')  # query
c = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')  # key
d = torch.randn(2, 513, 16, 64, dtype=torch.float16, device='cuda').transpose(1, 2)  # value (transposed)
e = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')  # out
f = torch.randn(2, 16, 513, device='cuda')  # logsumexp
g = None  # cum_seq_q
h = None  # cum_seq_k
i = 513  # max_q
j = 513  # max_k
k = 0.0  # dropout_p
l = False  # is_causal
m = torch.tensor(1, dtype=torch.int64)  # philox_seed
n = torch.tensor(1, dtype=torch.int64)  # philox_offset
out1_ref, out2_ref, out3_ref = torch.ops.aten._scaled_dot_product_flash_attention_backward(a, b, c, d, e, f, g, h, i, j, k, l, m, n, scale=0.125)
from torch._meta_registrations import meta__scaled_dot_product_flash_backward
out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(a, b, c, d, e, f, g, h, i, j, k, l, m, n, scale=0.125)
# prints True True
print(out1_ref.is_contiguous())
print(out1_test.is_contiguous())
# prints True True
print(out2_ref.is_contiguous())
print(out2_test.is_contiguous())
# prints True False
print(out3_ref.is_contiguous())
print(out3_test.is_contiguous())
```
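With the one-line kernel fix below, `dv` picks up `v`'s strides, so the eager kernel and the meta function agree; the new dynamo test asserts this by comparing the shapes and strides of all three gradients.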
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119500
Approved by: https://github.com/drisspg, https://github.com/ezyang, https://github.com/Skylion007
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
index 4d23e9f..07c9f7e 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
@@ -800,7 +800,7 @@
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
} else {
- dv = at::empty_like(k);
+ dv = at::empty_like(v);
}
// const at::Tensor& dout_padded = dout;
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 9b2fbaa..c00ec24 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -28,13 +28,15 @@
import torch._functorch.config
import torch.library
-
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import CompileCounter, rand_strided, same
from torch.nn import functional as F
+
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
disable_translation_validation_if_dynamic_shapes,
+ TEST_WITH_ROCM,
)
@@ -3998,6 +4000,57 @@
# frame_count should stay at 1.
self.assertEqual(cnt.frame_count, 1)
+ @unittest.skipIf(
+ TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+ "flash attention not supported",
+ )
+ def test_flash_attn_backward_mixed_strides(self):
+ # in this repro, "grad_out" and "value" are transposed tensors,
+ # but "key" and "value" are contiguous
+ def gen_inputs(device):
+ return (
+ torch.randn(
+ 2, 513, 16, 64, dtype=torch.float16, device=device
+ ).transpose(1, 2),
+ torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
+ torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
+ torch.randn(
+ 2, 513, 16, 64, dtype=torch.float16, device=device
+ ).transpose(1, 2),
+ torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
+ torch.randn(2, 16, 513, device=device),
+ None,
+ None,
+ 513,
+ 513,
+ 0.0,
+ False,
+ torch.tensor(1, dtype=torch.int64),
+ torch.tensor(1, dtype=torch.int64),
+ )
+
+ inps_cuda = gen_inputs("cuda")
+ inps_meta = gen_inputs("meta")
+ (
+ out1_ref,
+ out2_ref,
+ out3_ref,
+ ) = torch.ops.aten._scaled_dot_product_flash_attention_backward(
+ *inps_cuda, scale=0.125
+ )
+ from torch._meta_registrations import meta__scaled_dot_product_flash_backward
+
+ out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(
+ *inps_meta, scale=0.125
+ )
+
+ self.assertEqual(out1_ref.shape, out1_test.shape)
+ self.assertEqual(out1_ref.stride(), out1_test.stride())
+ self.assertEqual(out2_ref.shape, out2_test.shape)
+ self.assertEqual(out2_ref.stride(), out2_test.stride())
+ self.assertEqual(out3_ref.shape, out3_test.shape)
+ self.assertEqual(out3_ref.stride(), out3_test.stride())
+
def test_user_ctor_ctx_manager(self):
class UserCtxManager:
def __enter__(self):