flash_attention_helper mitigation: pass contiguous inputs (#85135)

There appears to be a transient issue with non-contiguous inputs in flash_attn, so we pass contiguous inputs to the kernel as a mitigation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85135
Approved by: https://github.com/drisspg
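
As a rough illustration of why the transpose-then-contiguous() approach helps (a minimal libtorch sketch with made-up shapes, not the actual nested-tensor path): selecting q/k/v along dim 1 of the packed [Nnz_q, 3, num_heads, head_dim] buffer produces strided, non-contiguous views, whereas leading with the q/k/v axis and materializing the buffer makes each dim-0 slice contiguous.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Hypothetical shapes for illustration only.
      const int64_t Nnz_q = 8, num_heads = 4, head_dim = 16;
      auto qkv = torch::randn({Nnz_q, 3, num_heads, head_dim});

      // Selecting q/k/v along dim 1 of the packed buffer yields
      // non-contiguous views (the outer stride still skips over k and v).
      auto q_view = qkv.index({torch::indexing::Slice(), 0});
      std::cout << q_view.is_contiguous() << "\n";  // 0 (false)

      // Transposing so the q/k/v axis leads and materializing with
      // .contiguous() makes every dim-0 slice a contiguous tensor.
      auto qkv_t = qkv.transpose(0, 1).contiguous();
      auto q = qkv_t[0];
      std::cout << q.is_contiguous() << "\n";  // 1 (true)
      return 0;
    }
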
diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp
index fc12ed1..35a1c83 100644
--- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp
@@ -296,18 +296,26 @@
     int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()};
 
     // For the packed case we need to set the output size for dim 2 to 1
-    auto atten_size = get_nested_size_tensor(query);
+    auto atten_size = get_nested_size_tensor(query).clone();
     atten_size.index({at::indexing::Slice(), 1}) = 1;
 
     auto qkv_buffer_reshaped =
-        get_buffer(query).view({Nnz_q, 3, num_heads, head_dim});
+        get_buffer(query).view({Nnz_q, 3, num_heads, head_dim}).transpose(0, 1).contiguous();
+
+    auto i0 = qkv_buffer_reshaped[0];
+    auto i1 = qkv_buffer_reshaped[1];
+    auto i2 = qkv_buffer_reshaped[2];
+
+    TORCH_CHECK(i0.is_contiguous());
+    TORCH_CHECK(i1.is_contiguous());
+    TORCH_CHECK(i2.is_contiguous());
 
     // If we are passing in query, key, value all the same tensors then we have
     // packed them into one tensor and need to slice for flash attention
     Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
-        qkv_buffer_reshaped.index({at::indexing::Slice(), 0}),
-        qkv_buffer_reshaped.index({at::indexing::Slice(), 1}),
-        qkv_buffer_reshaped.index({at::indexing::Slice(), 2}),
+        i0,
+        i1,
+        i2,
         cumulative_sequence_length_q,
         cumulative_sequence_length_q,
         max_seqlen_batch_q,