[cuDNN][Flash Attention] Minor cleanup for cuDNN SDPA (#120750)

Cleaning things up before (hopefully) starting work on the backward pass.
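
For anyone who wants to poke at the path this touches: a minimal sketch of routing SDPA through the cuDNN backend from Python, assuming `SDPBackend.CUDNN_ATTENTION` is already exposed through `torch.nn.attention.sdpa_kernel` on this branch (the exact selection API may differ):

```python
# Sketch only, not part of this PR: force the cuDNN SDPA backend so the
# TORCH_WARN("USING CUDNN SDPA") added below fires on dispatch.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.half)
           for _ in range(3))
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
```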

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120750
Approved by: https://github.com/Skylion007, https://github.com/drisspg
diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp
index 2a03582..c3f5f63 100644
--- a/aten/src/ATen/native/cudnn/MHA.cpp
+++ b/aten/src/ATen/native/cudnn/MHA.cpp
@@ -235,7 +235,7 @@
       fe::graph::Tensor_attributes()
           .set_name("Q")
           .set_dim(
-              std::vector<int64_t>(params.q_dim.begin(), params.q_stride.end()))
+              std::vector<int64_t>(params.q_dim.begin(), params.q_dim.end()))
           .set_stride(std::vector<int64_t>(
               params.q_stride.begin(), params.q_stride.end())));
   auto K = mha_graph->tensor(
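
Note on the hunk above: the old code built Q's `set_dim` vector from `params.q_dim.begin()` paired with `params.q_stride.end()`, i.e. iterators into two different containers, which is undefined behavior and can produce garbage dims. A minimal standalone sketch of the before/after pattern (values and names are illustrative, not the actual `params` struct):

```cpp
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> q_dim{2, 8, 128, 64};          // [B, H, S, D]
  std::vector<int64_t> q_stride{65536, 8192, 64, 1};  // contiguous strides

  // Before: iterators taken from two different vectors -> undefined behavior.
  // std::vector<int64_t> dims(q_dim.begin(), q_stride.end());

  // After: both iterators come from the same container, as in set_dim() above.
  std::vector<int64_t> dims(q_dim.begin(), q_dim.end());
  return 0;
}
```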
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 0cfe6da..961c706 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14571,7 +14571,7 @@
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
+- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
   dispatch:
     CUDA: _scaled_dot_product_cudnn_attention_cuda
   tags: nondeterministic_seeded
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index af4924c..421bc83 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -575,6 +575,7 @@
     switch (backend) {
       case SDPBackend::cudnn_attention:
         if (sdp::can_use_cudnn_attention(kernel_params, print_debug)) {
+              TORCH_WARN("USING CUDNN SDPA");
               return SDPBackend::cudnn_attention;
         }
         break;
diff --git a/test/test_transformers.py b/test/test_transformers.py
index c34def1..af14b06 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -2769,7 +2769,9 @@
         query_fudge_factor = 4
         grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
 
-        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad)
+        key_fudge_factor = 2
+        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
+
         value_fudge_factor = 2
         grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)