[BUG FIX] Refactor _scale_attn_mask_fusion_kernel to Use Runtime Argument Instead of Template Parameter (#132434)
**Description**
**_[BUG FIX]_**
This PR fixes a compilation failure that occurs with the GCC 11.4 compiler in FlashAttentionKernel.cpp. The issue does not reproduce on the PyTorch main branch by itself; it appears when our SVE PR changes (https://github.com/pytorch/pytorch/pull/119571) are applied on top of PyTorch main.
See the failing CI pipeline for our PR:
https://github.com/pytorch/pytorch/actions/runs/9895714768/job/27336251795?pr=119571
```
/var/lib/jenkins/workspace/build/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp.SVE256.cpp
during RTL pass: expand
In file included from /var/lib/jenkins/workspace/build/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp.SVE256.cpp:1:
/var/lib/jenkins/workspace/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp: In lambda function:
/var/lib/jenkins/workspace/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp:290:57: internal compiler error: in emit_move_insn, at expr.c:3821
290 | at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
| ^
0xffffb03f73fb __libc_start_call_main
../sysdeps/nptl/libc_start_call_main.h:58
0xffffb03f74cb __libc_start_main_impl
../csu/libc-start.c:392
Please submit a full bug report,
with preprocessed source if appropriate.
Please include the complete backtrace with any bug report.
See <file:///usr/share/doc/gcc-11/README.Bugs> for instructions.
[5731/6839] Building CXX object caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/cpu/CatKernel.cpp.SVE256.cpp.o
[5732/6839] Building CXX object caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/cpu/GridSamplerKernel.cpp.SVE256.cpp.o
```
The compilation failure happens only with GCC 11.4; the build succeeds with the newer GCC 12.3 compiler as well as with Clang. The failure is related to the `is_b_stride_zero` check, which was recently introduced into FlashAttentionKernel.cpp as a compile-time template parameter in https://github.com/pytorch/pytorch/commit/5da428d9ebab50be5974e228e8a83c54af02ecb6.
This PR works around the failure by passing `is_b_stride_zero` as a runtime argument instead of a template parameter when compiling with GCC 11.4 for SVE targets, leaving the template-parameter path in place for all other compilers.
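For context, a minimal sketch of the pattern change (scalar loop only; the real kernel in the diff below also has a vectorized path, and the function names here are purely illustrative):
```cpp
// Before: the mask-broadcast case is chosen at compile time via a bool
// template parameter and `if constexpr`, so two instantiations exist.
template <bool is_b_stride_zero, typename T1, typename T2>
inline void fusion_kernel_template(T1* a, T2* b, int size, T1* out, T1 val) {
  for (int i = 0; i < size; i++) {
    T1 tmp1;
    if constexpr (is_b_stride_zero) {
      tmp1 = (T1)b[0];  // mask broadcasting: treat b as a scalar pointer
    } else {
      tmp1 = (T1)b[i];
    }
    out[i] = a[i] * val + tmp1;
  }
}

// After (GCC 11.4 + SVE builds only): the same decision is made at run time
// via an extra bool argument, which avoids the internal compiler error
// during RTL expansion.
template <typename T1, typename T2>
inline void fusion_kernel_runtime(T1* a, T2* b, int size, T1* out, T1 val,
                                  bool is_b_stride_zero) {
  for (int i = 0; i < size; i++) {
    T1 tmp1 = is_b_stride_zero ? (T1)b[0] : (T1)b[i];
    out[i] = a[i] * val + tmp1;
  }
}
```
At the call sites, the `mStrideN == 0` branch that previously selected a template specialization is folded into the mask pointer offset and the trailing bool argument, as shown in the diff.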
cc @Valentine233 @yanbing-j @mingfeima @malfet @jgong5 @r-barnes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132434
Approved by: https://github.com/jgong5
diff --git a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp
index eae9f8b..9d5575e 100644
--- a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp
+++ b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp
@@ -24,6 +24,16 @@
// out = val * a + b
// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case),
// take b as a scalar pointer.
+#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
+template <typename T1, typename T2>
+inline void _scale_attn_mask_fusion_kernel(
+ T1* a,
+ T2* b,
+ const int& size,
+ T1* out,
+ T1& val,
+ bool is_b_stride_zero) {
+#else
template <bool is_b_stride_zero, typename T1, typename T2>
inline void _scale_attn_mask_fusion_kernel(
T1* a,
@@ -31,6 +41,7 @@
const int& size,
T1* out,
T1& val) {
+#endif
const auto vec_size1 = at::vec::Vectorized<T1>::size();
const auto vec_size2 = at::vec::Vectorized<T2>::size();
constexpr int64_t T1_n =
@@ -41,7 +52,11 @@
for (; i < size - (size % vec_size2); i += vec_size2) {
auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
at::vec::VectorizedN<T2, T2_n> b_n;
+#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
+ if (is_b_stride_zero) {
+#else
if constexpr(is_b_stride_zero) {
+#endif
b_n = at::vec::VectorizedN<T2, T2_n>((T1)b[0]);
} else {
b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
@@ -53,7 +68,11 @@
for (; i < size; i++) {
auto tmp0 = a[i];
T1 tmp1;
+#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
+ if (is_b_stride_zero) {
+#else
if constexpr(is_b_stride_zero) {
+#endif
tmp1 = (T1)b[0];
} else {
tmp1 = (T1)b[i];
@@ -342,23 +361,34 @@
// qk <- qk * scaling + attn_mask
if (has_attn_mask) {
for (int64_t row = 0; row < qBlockSize; ++row) {
- if (mStrideN == 0) {
- _scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
+#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
+ _scale_attn_mask_fusion_kernel(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
- (m + row) * mStrideM,
+ (m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
kvBlockSize,
qk_data + row * kvBlockSize,
- scaling_factor);
- } else {
- _scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
- qk_data + row * kvBlockSize,
- mask_data + i * mStrideB + j * mStrideH +
- (m + row) * mStrideM + n,
- kvBlockSize,
- qk_data + row * kvBlockSize,
- scaling_factor);
- }
+ scaling_factor,
+ mStrideN == 0);
+#else
+ if (mStrideN == 0) {
+ _scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
+ qk_data + row * kvBlockSize,
+ mask_data + i * mStrideB + j * mStrideH +
+ (m + row) * mStrideM,
+ kvBlockSize,
+ qk_data + row * kvBlockSize,
+ scaling_factor);
+ } else {
+ _scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
+ qk_data + row * kvBlockSize,
+ mask_data + i * mStrideB + j * mStrideH +
+ (m + row) * mStrideM + n,
+ kvBlockSize,
+ qk_data + row * kvBlockSize,
+ scaling_factor);
+ }
+#endif
}
}
// Update coefficients with Softmax
@@ -617,23 +647,34 @@
if (has_attn_mask) {
accum_t one = accum_t(1);
for (const auto row : c10::irange(qBlockSize)) {
- if (mStrideN == 0) {
- _scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
+#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
+ _scale_attn_mask_fusion_kernel(
attn_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
- (m + row) * mStrideM,
+ (m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
kvBlockSize,
attn_data + row * kvBlockSize,
- one);
- } else {
- _scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
- attn_data + row * kvBlockSize,
- mask_data + i * mStrideB + j * mStrideH +
- (m + row) * mStrideM + n,
- kvBlockSize,
- attn_data + row * kvBlockSize,
- one);
- }
+ one,
+ mStrideN == 0);
+#else
+ if (mStrideN == 0) {
+ _scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
+ attn_data + row * kvBlockSize,
+ mask_data + i * mStrideB + j * mStrideH +
+ (m + row) * mStrideM,
+ kvBlockSize,
+ attn_data + row * kvBlockSize,
+ one);
+ } else {
+ _scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
+ attn_data + row * kvBlockSize,
+ mask_data + i * mStrideB + j * mStrideH +
+ (m + row) * mStrideM + n,
+ kvBlockSize,
+ attn_data + row * kvBlockSize,
+ one);
+ }
+#endif
}
}
// restore self attention after softmax from logsumexp