make flash_attn_bw impl correct w.r.t. meta when k and v have different strides (#119500)

`dv = at::empty_like(k)` and `dv = at::empty_like(v)` can be materially different, because `empty_like` tries to preserve the strides of its input when possible. So if `k` is contiguous but `v` is transposed, then before this PR `dv` would come out contiguous, instead of matching `v`'s (transposed) strides the way the meta function does.
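
As a quick illustration (not part of the PR), here is a sketch of how `empty_like` preserves the strides of a dense but non-contiguous input, using the same layouts as the repro below:
```
import torch

# "k" is contiguous; "v" is a transposed (non-contiguous but dense) view,
# mirroring the layouts in the repro below
k = torch.randn(2, 16, 513, 64)
v = torch.randn(2, 513, 16, 64).transpose(1, 2)

# empty_like defaults to memory_format=torch.preserve_format, so the result
# keeps the input's strides when the input is non-overlapping and dense
print(torch.empty_like(k).stride())  # (525312, 32832, 64, 1) - contiguous
print(torch.empty_like(v).stride())  # (525312, 64, 1024, 1) - transposed layout
```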

Alternatively, we could change the meta implementation of `aten._scaled_dot_product_flash_attention_backward` to this:
```
    grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
    grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
    grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2)
    return grad_q, grad_k, grad_v
```

But (I think?) the `empty_like(k)` in the sdpa backward impl was just a typo, so this PR fixes the kernel rather than the meta function.

I noticed this because changing the meta formula as above was enough to fix the issue with the `aot_eager` backend in this [link](https://github.com/pytorch/pytorch/issues/116935#issuecomment-1914310523).

A minimal repro that I made looks like this:
```
import torch

# in this repro, "grad_out" and "value" are transposed tensors,
# but "query" and "key" are contiguous
a = torch.randn(2, 513, 16, 64, dtype=torch.float16, device='cuda').transpose(1, 2)  # grad_out (transposed)
b = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')  # query
c = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')  # key
d = torch.randn(2, 513, 16, 64, dtype=torch.float16, device='cuda').transpose(1, 2)  # value (transposed)
e = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')  # out
f = torch.randn(2, 16, 513, device='cuda')  # logsumexp
g = None  # cum_seq_q
h = None  # cum_seq_k
i = 513  # max_q
j = 513  # max_k
k = 0.0  # dropout_p
l = False  # is_causal
m = torch.tensor(1, dtype=torch.int64)  # philox_seed
n = torch.tensor(1, dtype=torch.int64)  # philox_offset

out1_ref, out2_ref, out3_ref = torch.ops.aten._scaled_dot_product_flash_attention_backward(a, b, c, d, e, f, g, h, i, j, k, l, m, n, scale=0.125)

from torch._meta_registrations import meta__scaled_dot_product_flash_backward
out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(a, b, c, d, e, f, g, h, i, j, k, l, m, n, scale=0.125)

# grad_q: prints True True
print(out1_ref.is_contiguous())
print(out1_test.is_contiguous())

# grad_k: prints True True
print(out2_ref.is_contiguous())
print(out2_test.is_contiguous())

# grad_v: prints True False before this PR (the stride mismatch being fixed)
print(out3_ref.is_contiguous())
print(out3_test.is_contiguous())
```
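
Continuing from the repro above (reusing `out1_ref` through `out3_test`), a stricter check than `is_contiguous()` is to compare the full stride tuples, which is what the regression test added in this PR asserts:
```
# out3 is grad_v: before this PR the CUDA kernel allocated it based on key
# (contiguous here) while the meta function bases it on value (transposed
# here), so the stride tuples disagree
print(out3_ref.stride())
print(out3_test.stride())
assert out1_ref.stride() == out1_test.stride()  # grad_q strides agree
assert out2_ref.stride() == out2_test.stride()  # grad_k strides agree
```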

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119500
Approved by: https://github.com/drisspg, https://github.com/ezyang, https://github.com/Skylion007
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
index 4d23e9f..07c9f7e 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
@@ -800,7 +800,7 @@
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
         CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
     } else {
-        dv = at::empty_like(k);
+        dv = at::empty_like(v);
     }
 
     // const at::Tensor& dout_padded = dout;
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 9b2fbaa..c00ec24 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -28,13 +28,15 @@
 
 import torch._functorch.config
 import torch.library
-
 from torch import nn
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import CompileCounter, rand_strided, same
 from torch.nn import functional as F
+
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import (
     disable_translation_validation_if_dynamic_shapes,
+    TEST_WITH_ROCM,
 )
 
 
@@ -3998,6 +4000,57 @@
                 # frame_count should stay at 1.
                 self.assertEqual(cnt.frame_count, 1)
 
+    @unittest.skipIf(
+        TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "flash attention not supported",
+    )
+    def test_flash_attn_backward_mixed_strides(self):
+        # in this repro, "grad_out" and "value" are transposed tensors,
+        # but "query" and "key" are contiguous
+        def gen_inputs(device):
+            return (
+                torch.randn(
+                    2, 513, 16, 64, dtype=torch.float16, device=device
+                ).transpose(1, 2),
+                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
+                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
+                torch.randn(
+                    2, 513, 16, 64, dtype=torch.float16, device=device
+                ).transpose(1, 2),
+                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
+                torch.randn(2, 16, 513, device=device),
+                None,
+                None,
+                513,
+                513,
+                0.0,
+                False,
+                torch.tensor(1, dtype=torch.int64),
+                torch.tensor(1, dtype=torch.int64),
+            )
+
+        inps_cuda = gen_inputs("cuda")
+        inps_meta = gen_inputs("meta")
+        (
+            out1_ref,
+            out2_ref,
+            out3_ref,
+        ) = torch.ops.aten._scaled_dot_product_flash_attention_backward(
+            *inps_cuda, scale=0.125
+        )
+        from torch._meta_registrations import meta__scaled_dot_product_flash_backward
+
+        out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(
+            *inps_meta, scale=0.125
+        )
+
+        self.assertEqual(out1_ref.shape, out1_test.shape)
+        self.assertEqual(out1_ref.stride(), out1_test.stride())
+        self.assertEqual(out2_ref.shape, out2_test.shape)
+        self.assertEqual(out2_ref.stride(), out2_test.stride())
+        self.assertEqual(out3_ref.shape, out3_test.shape)
+        self.assertEqual(out3_ref.stride(), out3_test.stride())
+
     def test_user_ctor_ctx_manager(self):
         class UserCtxManager:
             def __enter__(self):