flash_attention_helper mitigation: pass contiguous inputs (#85135)
There appears to be a transient issue with non-contiguous inputs in flash_attn, so we pass contiguous inputs to mitigate it.
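
For reference, a minimal standalone sketch (not part of this PR; assumes ATen/libtorch is available and uses illustrative shapes) of why slicing the packed [Nnz, 3, num_heads, head_dim] buffer along dim 1 yields non-contiguous views, while transposing so the q/k/v index is outermost and calling .contiguous() first yields contiguous slices:

// Illustrative sketch only: shapes are made up, builds against libtorch.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  const int64_t Nnz = 8, num_heads = 4, head_dim = 16;
  auto qkv = at::randn({Nnz, 3, num_heads, head_dim});

  // Slicing the packed buffer directly gives a strided, non-contiguous view.
  auto q_view = qkv.index({at::indexing::Slice(), 0});
  std::cout << "direct slice contiguous: " << q_view.is_contiguous() << "\n";  // 0

  // Mitigation: move the q/k/v index to the outermost dimension, materialize
  // the memory with .contiguous(), then take each slice.
  auto qkv_t = qkv.transpose(0, 1).contiguous();
  auto q = qkv_t[0];
  std::cout << "transposed slice contiguous: " << q.is_contiguous() << "\n";  // 1
  return 0;
}

The diff below applies the same transpose + .contiguous() reshuffle before handing the q/k/v slices to _flash_scaled_dot_product_attention.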
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85135
Approved by: https://github.com/drisspg
diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp
index fc12ed1..35a1c83 100644
--- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp
@@ -296,18 +296,26 @@
int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()};
// For the packed case we need to set the output size for dim 2 to 1
- auto atten_size = get_nested_size_tensor(query);
+ auto atten_size = get_nested_size_tensor(query).clone();
atten_size.index({at::indexing::Slice(), 1}) = 1;
auto qkv_buffer_reshaped =
- get_buffer(query).view({Nnz_q, 3, num_heads, head_dim});
+ get_buffer(query).view({Nnz_q, 3, num_heads, head_dim}).transpose(0, 1).contiguous();
+
+ auto i0 = qkv_buffer_reshaped[0];
+ auto i1 = qkv_buffer_reshaped[1];
+ auto i2 = qkv_buffer_reshaped[2];
+
+ TORCH_CHECK(i0.is_contiguous());
+ TORCH_CHECK(i1.is_contiguous());
+ TORCH_CHECK(i2.is_contiguous());
// If we are passing in query, key, value all the same tensors then we have
// packed them into one tensor and need to slice for flash attention
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
- qkv_buffer_reshaped.index({at::indexing::Slice(), 0}),
- qkv_buffer_reshaped.index({at::indexing::Slice(), 1}),
- qkv_buffer_reshaped.index({at::indexing::Slice(), 2}),
+ i0,
+ i1,
+ i2,
cumulative_sequence_length_q,
cumulative_sequence_length_q,
max_seqlen_batch_q,