blob: 7956d0c7f5f51fb5ffffbd797629f9d147b89310 [file] [log] [blame]
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif
namespace at::native {
namespace {
// Fused scale + mask-add over one row of `size` elements:
//   out[i] = val * a[i] + b[i]
// a / out have element type T1, the mask b has element type T2 and is
// converted to T1 before the add.
// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case),
// take b as a scalar pointer, i.e. b[0] is used for every element.
//
// Two signatures are provided. For GCC 11.4+ with SVE enabled,
// is_b_stride_zero is a runtime argument; otherwise it is a template
// parameter tested with `if constexpr`.
// NOTE(review): presumably the runtime variant is a compiler workaround for
// that toolchain -- confirm.
#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
template <typename T1, typename T2>
inline void _scale_attn_mask_fusion_kernel(
T1* a,
T2* b,
const int& size,
T1* out,
T1& val,
bool is_b_stride_zero) {
#else
template <bool is_b_stride_zero, typename T1, typename T2>
inline void _scale_attn_mask_fusion_kernel(
T1* a,
T2* b,
const int& size,
T1* out,
T1& val) {
#endif
const auto vec_size1 = at::vec::Vectorized<T1>::size();
const auto vec_size2 = at::vec::Vectorized<T2>::size();
// When T2 is a reduced floating point type whose vector holds twice as many
// lanes as a T1 vector, two T1 vectors are processed per T2 vector.
constexpr int64_t T1_n =
(vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
constexpr int64_t T2_n = 1;
auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
int64_t i = 0;
// Vectorized main loop: advances by one T2 vector (vec_size2 elements).
for (; i < size - (size % vec_size2); i += vec_size2) {
auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
at::vec::VectorizedN<T2, T2_n> b_n;
#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
if (is_b_stride_zero) {
#else
if constexpr(is_b_stride_zero) {
#endif
// Broadcast the single mask value across the whole vector.
b_n = at::vec::VectorizedN<T2, T2_n>((T1)b[0]);
} else {
b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
}
auto b_n_convert = at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
auto res = a_n * vec_scale + b_n_convert;
res.store(out + i);
}
// Scalar tail for the remaining (size % vec_size2) elements.
for (; i < size; i++) {
auto tmp0 = a[i];
T1 tmp1;
#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
if (is_b_stride_zero) {
#else
if constexpr(is_b_stride_zero) {
#endif
tmp1 = (T1)b[0];
} else {
tmp1 = (T1)b[i];
}
out[i] = tmp0 * val + tmp1;
}
}
// 1) out = exp(a - val)
// 2) val = sum(out)
template <typename T1, typename T2>
inline void _exp_reduce_sum_fusion_kernel(
T1* a,
const int& size,
T2* out,
T1& val) {
auto vec_size = vec::Vectorized<T1>::size();
auto vec_max = vec::Vectorized<T1>(val);
T1 tmp_sum = 0;
auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
auto tmp1 = tmp0 - vec_max;
auto tmp2 = tmp1.exp_u20();
vec_tmp_sum += tmp2;
_store(out + i, tmp2);
}
tmp_sum = vec::vec_reduce_all<T1>(
[](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) {
return x + y;
},
vec_tmp_sum);
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = tmp0 - val;
auto tmp2 = exp(tmp1);
tmp_sum += tmp2;
out[i] = tmp2;
}
val = tmp_sum;
}
// 1) out = a * scale
// 2) max = max(out)
template <typename scalar_t>
inline void _mul_reduce_max_fusion_kernel(
const scalar_t* a,
const scalar_t& scale,
const int& size,
scalar_t* out,
scalar_t& max) {
auto vec_size = vec::Vectorized<scalar_t>::size();
auto vec_scale = vec::Vectorized<scalar_t>(scale);
scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
auto vec_tmp_max = vec::Vectorized<scalar_t>(tmp_max);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = vec::Vectorized<scalar_t>::loadu(a + i);
auto tmp1 = tmp0 * vec_scale;
vec_tmp_max = vec::maximum(vec_tmp_max, tmp1);
_store(out + i, tmp1);
}
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = tmp0 * scale;
tmp_max = std::max(tmp_max, tmp1);
out[i] = tmp1;
}
max = std::max(
tmp_max,
vec::vec_reduce_all<scalar_t>(
[](vec::Vectorized<scalar_t>& x, vec::Vectorized<scalar_t>& y) {
return vec::maximum(x, y);
},
vec_tmp_max));
}
// Overload selected when both pointers share the same element type.
// Checks that ptr2 is null and returns ptr. The callers in this file pass
// the accumulation buffer as ptr and the reduced-precision buffer as ptr2.
template <typename scalar_t>
static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
// Raises via TORCH_CHECK if a second buffer was unexpectedly supplied.
TORCH_CHECK(ptr2 == nullptr);
return ptr;
}
// Overload for reduced floating point element types: the float buffer is
// ignored and the reduced-precision buffer (ptr2) is always returned.
template <
    typename scalar_t,
    std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
static inline scalar_t* conditional_data_ptr(float* /* ptr */, scalar_t* ptr2) {
  return ptr2;
}
template <typename scalar_t>
inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
using Vec = Vectorized<scalar_t>;
Vec data_vec = Vec(val);
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
data_vec.store(data + d);
}
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
# pragma unroll
#endif
for (; d < size; d++) {
data[d] = val;
}
}
// Rewrites attn_mask in place as a 4-d view expanded over the query and
// key/value lengths.
// Supported mask shapes (guaranteed in check_attn_mask_shape):
//   2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
//   4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
// The two leading dimensions of the result are batchSize / num_head only
// when a 4-d mask already has exactly that size there; otherwise they are 1.
void reshape_attn_mask_to_4d(
    Tensor& attn_mask,
    int64_t batchSize,
    int64_t num_head,
    int64_t qSize,
    int64_t kvSize) {
  const bool mask_is_4d = attn_mask.dim() == 4;
  const int64_t lead_dim_0 =
      (mask_is_4d && attn_mask.size(0) == batchSize) ? batchSize : 1;
  const int64_t lead_dim_1 =
      (mask_is_4d && attn_mask.size(1) == num_head) ? num_head : 1;

  // Keep the trailing two mask dimensions for the view, then broadcast them
  // to the full (qSize, kvSize) extent.
  const auto viewed = attn_mask.view(
      {lead_dim_0, lead_dim_1, attn_mask.size(-2), attn_mask.size(-1)});
  attn_mask = viewed.expand({lead_dim_0, lead_dim_1, qSize, kvSize});
}
template <typename scalar_t>
inline void copy_value_with_pad(
const scalar_t* value_ptr,
scalar_t* dst_ptr,
int64_t rows,
int64_t cols,
int64_t prows,
int64_t pcols,
int64_t ldi) {
auto vec_size = at::vec::Vectorized<scalar_t>::size();
int64_t i = 0;
for (; i < rows; i++) {
int64_t j = 0;
for (; j < cols - (cols % vec_size); j += vec_size) {
auto vec_v =
at::vec::Vectorized<scalar_t>::loadu(value_ptr + i * ldi + j);
vec_v.store(dst_ptr + i * pcols + j);
}
if (j < cols) {
auto vec_v = at::vec::Vectorized<scalar_t>::loadu(
value_ptr + i * ldi + j, cols - j);
vec_v.store(dst_ptr + i * pcols + j, cols - j);
}
// col padding
auto psize = pcols - cols;
if (psize > 0) {
auto zero_vec = at::vec::Vectorized<scalar_t>(0);
int64_t pj = 0;
for (; pj < psize - (psize % vec_size); pj += vec_size) {
zero_vec.store(dst_ptr + i * pcols + cols + pj);
}
if (pj < psize) {
zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj);
}
}
}
// row padding
for (; i < prows; i++) {
auto zero_vec = at::vec::Vectorized<scalar_t>(0);
int64_t j = 0;
for (; j < pcols - (pcols % vec_size); j += vec_size) {
zero_vec.store(dst_ptr + i * pcols + j);
}
if (j < pcols) {
zero_vec.store(dst_ptr + i * pcols + j, pcols - j);
}
}
}
template <typename scalar_t>
inline void pad_remain_row_col_zero(
scalar_t* value_ptr,
int rows,
int cols,
int prows,
int pcols,
int ldi) {
auto psize = pcols - cols;
if (psize == 0 && prows == rows) {
return;
}
auto vec_size = at::vec::Vectorized<scalar_t>::size();
auto zero = at::vec::Vectorized<scalar_t>(0);
if (psize > 0) {
for (int i = 0; i < rows; i++) {
int j = 0;
for (; j < psize - (psize % vec_size); j += vec_size) {
zero.store(value_ptr + i * ldi + cols + j);
}
if (j < psize) {
zero.store(value_ptr + i * ldi + cols + j, psize - j);
}
}
}
for (int i = rows; i < prows; i++) {
int j = 0;
for (; j < pcols - (pcols % vec_size); j += vec_size) {
zero.store(value_ptr + i * ldi + j);
}
if (j < pcols) {
zero.store(value_ptr + i * ldi + j, pcols - j);
}
}
}
// Forward pass of blocked scaled-dot-product attention on CPU.
//
// The query is processed in blocks of at most q_split_size rows and the
// key/value in blocks of at most kv_split_size rows. For each query block a
// running row maximum and row sum are kept while iterating over the kv
// blocks; the accumulated result is normalized by the row sum at the end and
// written to `output`, and max + log(sum) is written to `logsumexp`.
//
// Template parameters:
//   scalar_t      element type of q/k/v/output
//   mask_t        element type of attn_mask
//   q_split_size  upper bound of the query block size
//   kv_split_size upper bound of the key/value block size
//   with_pack     allow the packed (cpublas::pack + cpublas::brgemm) path;
//                 whether it is actually taken is decided by the need_pack
//                 heuristic below
//
// NOTE(review): dropout_p is not read anywhere in this function.
// NOTE(review): when need_pack is true the two GEMMs are only issued inside
// `if constexpr (std::is_same_v<scalar_t, at::Half>)`; presumably callers
// only enable with_pack for at::Half -- confirm.
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size, bool with_pack=false>
void cpu_flash_attention(
const Tensor& output,
const Tensor& logsumexp,
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
double dropout_p,
bool is_causal,
std::optional<Tensor> attn_mask,
std::optional<double> scale) {
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
// -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
// -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
at::Tensor query = q.transpose(1, 2);
at::Tensor key = k.transpose(1, 2);
at::Tensor value = v.transpose(1, 2);
constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
using accum_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<accum_t>;
accum_t scaling_factor =
sdp::calculate_scale(query, scale).as_float_unchecked();
// Sizes
TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
"scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
int64_t batchSize = query.size(0);
int64_t qSize = query.size(1);
int64_t kvSize = value.size(1);
int64_t num_head = query.size(2);
int64_t headSize = query.size(3);
// An empty mask tensor is treated the same as no mask.
bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
if (has_attn_mask) {
reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
}
// Strides
int64_t qStrideB = query.stride(0);
int64_t qStrideM = query.stride(1);
int64_t qStrideH = query.stride(2);
int64_t kStrideB = key.stride(0);
int64_t kStrideN = key.stride(1);
int64_t kStrideH = key.stride(2);
int64_t vStrideB = value.stride(0);
int64_t vStrideN = value.stride(1);
int64_t vStrideH = value.stride(2);
int64_t oStrideB = output.stride(0);
int64_t oStrideM = output.stride(1);
int64_t oStrideH = output.stride(2);
int64_t lStrideB = logsumexp.stride(0);
int64_t lStrideM = logsumexp.stride(1);
int64_t lStrideH = logsumexp.stride(2);
// Mask strides: a dimension of size 1 (or no mask at all) gets stride 0 so
// that the same element is reused along that dimension.
int64_t mStrideB =
(has_attn_mask && attn_mask.value().size(0) > 1)
? attn_mask.value().stride(0)
: 0;
int64_t mStrideH =
(has_attn_mask && attn_mask.value().size(1) > 1)
? attn_mask.value().stride(1)
: 0;
int64_t mStrideM =
(has_attn_mask && attn_mask.value().size(2) > 1)
? attn_mask.value().stride(2)
: 0;
int64_t mStrideN =
(has_attn_mask && attn_mask.value().size(3) > 1)
? attn_mask.value().stride(3)
: 0;
// Effective block sizes (clamped to the sequence lengths) and block counts.
int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
// Number of rows in the last kv block.
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
int64_t num_thread = at::get_num_threads();
const auto dtype = query.scalar_type();
const auto accumulate_dtype = toOpMathType(dtype);
// Whether pack is needed
bool need_pack = false;
// Block size of packing B matrix
int64_t packb_size = 64;
// Use packb_size due to the limitation:
// oneDNN pack only supports output leading dimension being one of (16, 32, 48, 64)
// For instance,
// for q @ k.T [qSplitSize, headSize] * [headSize, kvSplitSize] = [qSplitSize, kvSplitSize],
// we need to split kvSplitSize with packb_size for packing k.T,
// for (q @ k.T) @ v [qSplitSize, kvSplitSize] x [kvSplitSize, headSize] -> [qSplitSize, headSize],
// we need to split headSize with packb_size for packing v
// TODO Simplify the check when oneDNN supports fused pack with transpose and has better performance
if (with_pack) {
need_pack = num_head >= 4 && headSize % packb_size == 0 && kvSize >= packb_size;
if (need_pack) {
// Both estimates are computed in integer arithmetic (including the
// division by 1024) before being stored as float.
float pack_size = batchSize * num_head * kvSize * headSize / 1024;
float gemm_size_per_thread =
(batchSize * num_head * qSlice + num_thread - 1) / num_thread *
qSplitSize * (is_causal ? qSize : kvSize) * headSize / 1024;
float gsize = gemm_size_per_thread / pack_size;
// When the number of gemm is much greater than the number of pack,
// the pack and padding overhead can be overlapped.
if (pack_size < 2688) {
need_pack = gsize >= 36 || (gsize >= 24 && headSize > packb_size);
} else if (pack_size < 16384) {
need_pack = gsize >= (is_causal ? 54 : 52);
} else {
need_pack = gsize >= (is_causal ? 54 : 40);
}
}
}
// r* sizes: rounded up to a multiple of packb_size when packing is used.
int64_t rHeadSize = need_pack ? (headSize + packb_size - 1) / packb_size * packb_size : headSize;
int64_t rkvSplitSize = need_pack ? (kvSplitSize + packb_size - 1) / packb_size * packb_size : kvSplitSize;
int64_t rkvTail = need_pack ? (kvTail + packb_size - 1) / packb_size * packb_size : kvTail;
int64_t rkvSize = kv_split_size > kvSize ? rkvTail : rkvSplitSize * kvSlice + rkvTail;
// oneDNN pack does not support odd K now, we need also pad odd K
// e* sizes: rounded up to the next even number when packing is used.
bool headSize_even = headSize % 2 == 0;
int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize;
int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail;
// allocate per thread temp buf (accumulate type)
int64_t size_per_thread =
/* qk */ qSplitSize * rkvSplitSize +
/* qk_max */ qSplitSize +
/* qk_sum */ qSplitSize +
/* dst */ qSplitSize * rHeadSize;
at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
// Per thread buffer holding the softmax result in scalar_t; it has zero
// elements unless scalar_t is a reduced floating point type.
at::Tensor buf_reduced = at::empty(
{num_thread,
qSplitSize,
is_reduced_type ? ekvSplitSize : 0},
query.options());
// Data ptrs
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
const scalar_t* k_data = key.const_data_ptr<scalar_t>();
const scalar_t* v_data = value.const_data_ptr<scalar_t>();
mask_t* mask_data = has_attn_mask
? attn_mask.value().data_ptr<mask_t>()
: nullptr;
scalar_t* out_data = output.data_ptr<scalar_t>();
accum_t* lse_data = logsumexp.data_ptr<accum_t>();
accum_t* buf_data = buf.data_ptr<accum_t>();
scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;
// Buffer to store padding query
scalar_t* query_padding_ptr = nullptr;
std::unique_ptr<scalar_t[]> query_padding_data;
if (!headSize_even && need_pack) {
query_padding_data = std::make_unique<scalar_t[]>(num_thread * qSplitSize * eheadSize);
query_padding_ptr = query_padding_data.get();
}
// Buffer to store Key and Value after transforms
scalar_t* key_reorder_ptr = nullptr;
std::unique_ptr<scalar_t[]> key_reorder_data;
scalar_t* value_reorder_ptr = nullptr;
std::unique_ptr<scalar_t[]> value_reorder_data;
// Total padded kv length: full blocks at ekvSplitSize plus the padded tail.
int kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail;
if (need_pack) {
key_reorder_data = std::make_unique<scalar_t[]>(batchSize * num_head * eheadSize * rkvSize);
key_reorder_ptr = key_reorder_data.get();
value_reorder_data = std::make_unique<scalar_t[]>(batchSize * num_head * kv_padding_size * rHeadSize);
value_reorder_ptr = value_reorder_data.get();
}
// Reorder K, V
// Parallel over (batch, head, kv block); each task packs one kv block of K
// (transposed) and V into the reorder buffers.
if (need_pack) {
at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, l = 0, n = 0;
at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
std::unique_ptr<scalar_t[]> transpose_buffer = std::make_unique<scalar_t[]>(eheadSize * packb_size);
scalar_t* transpose_buffer_ptr = transpose_buffer.get();
std::unique_ptr<scalar_t[]> v_copy_buffer = std::make_unique<scalar_t[]>(ekvSplitSize * packb_size);
scalar_t* v_copy_buffer_ptr = v_copy_buffer.get();
for (const auto z : c10::irange(begin, end)) {
(void)z; // Suppress unused variable
n = l * kvSplitSize;
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
int64_t ekvBlockSize = kvBlockSize % 2 == 0 ? kvBlockSize : kvBlockSize + 1;
// Split kvSplitSize with packb_size
// [kvSplitSize, headSize] -> [div_up(kvSplitSize, packb_size), packb_size, headSize]
// Transpose [packb_size, headSize] -> [headSize, packb_size]
// Pack transposed buffer
for (int64_t b = 0; b < kvBlockSize; b += packb_size) {
bool tail = kvBlockSize - b < packb_size;
// TODO Use fused pack with transpose support when oneDNN supports such usage
// The data is transposed as raw 16-bit elements, hence the casts to
// uint16_t pointers.
utils::transpose<uint16_t>(
tail ? kvBlockSize - b : packb_size,
headSize,
/* src_ptr */
reinterpret_cast<const uint16_t*>(
k_data + i * kStrideB + j * kStrideH + n * kStrideN +
b * kStrideN),
/* ld_src */ kStrideN,
/* dst */ reinterpret_cast<uint16_t*>(transpose_buffer_ptr),
/* ld_dst */ packb_size);
// Pad [headSize, x] -> [eheadSize, x]
if (!headSize_even) {
pad_remain_row_col_zero<scalar_t>(
transpose_buffer_ptr,
headSize,
packb_size,
eheadSize,
packb_size,
packb_size);
}
// Pack
cpublas::pack(
/* K */ eheadSize,
/* N */ packb_size,
/* ld_in */ packb_size,
/* ld_out */ packb_size,
/* dt_in */ dtype,
/* dt_out */ dtype,
transpose_buffer_ptr,
key_reorder_ptr + i * num_head * eheadSize * rkvSize +
j * eheadSize * rkvSize + n * eheadSize + b * eheadSize);
}
// Split headSize with packb_size
// [kvSplitSize, headSize] -> [kvSplitSize, div_up(headSize, packb_size), packb_size]
for (int64_t b = 0; b < headSize; b += packb_size) {
// Do copy due to the limitation of input_ld of oneDNN pack:
// Regarding packing [K, N], only input_ld == N is supported
// TODO: remove the copy when pack supports input_ld >= N
copy_value_with_pad<scalar_t>(
v_data + i * vStrideB + j * vStrideH + n * vStrideN + b,
v_copy_buffer_ptr,
kvBlockSize,
(headSize - b < packb_size) ? headSize - b : packb_size,
ekvBlockSize,
packb_size,
vStrideN);
cpublas::pack(
ekvBlockSize,
packb_size,
packb_size,
packb_size,
dtype,
dtype,
v_copy_buffer_ptr,
value_reorder_ptr +
i * num_head * kv_padding_size * rHeadSize +
j * kv_padding_size * rHeadSize + n * rHeadSize +
ekvBlockSize * b);
}
// Move to the next (batch, head, kv block) index
at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
}
});
}
// Main attention loop, parallel over (batch, head, query block).
at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, k = 0;
data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
// Carve this thread's slice out of the shared temp buffers:
// [qk | qk_max | qk_sum | dst]
int ompIdx = at::get_thread_num();
accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
accum_t* qk_data = buf_ptr;
accum_t* qk_max_data = qk_data + qSplitSize * rkvSplitSize;
accum_t* qk_sum_data = qk_max_data + qSplitSize;
accum_t* dst_data = qk_sum_data + qSplitSize;
scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize : nullptr;
scalar_t* query_t_padding_ptr = (!headSize_even && need_pack)
? query_padding_ptr + ompIdx * qSplitSize * eheadSize
: nullptr;
for (const auto z : c10::irange(begin, end)) {
(void)z; // Suppress unused variable
int64_t m = k * qSplitSize;
int64_t qBlockSize = std::min(qSplitSize, qSize - m);
// Initialize max and sum
fill_stub(qk_max_data,
-std::numeric_limits<accum_t>::infinity(), qBlockSize);
fill_stub(qk_sum_data,
static_cast<accum_t>(0), qBlockSize);
// With causal masking only keys up to the end of this query block are
// visited.
int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
if (!headSize_even && need_pack) {
// Pad query if headSize is not even
// [qBlockSize, headSize] -> [qBlockSize, eheadSize]
copy_value_with_pad<scalar_t>(
q_data + i * qStrideB + j * qStrideH + m * qStrideM,
query_t_padding_ptr,
qBlockSize,
headSize,
qBlockSize,
eheadSize,
qStrideM
);
}
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
int64_t ekvBlockSize = (need_pack && kvBlockSize % 2 != 0) ? kvBlockSize + 1 : kvBlockSize;
// Row stride of qk_data for this kv block.
int64_t rkvBlockSize = kvBlockSize == kvSplitSize ? rkvSplitSize : rkvTail;
// Calculate scale * q @ k.T
// (only q @ k.T is computed here; the scaling factor is applied in the
// mask / softmax steps below)
if (need_pack) {
if constexpr (std::is_same_v<scalar_t, at::Half>) {
for (int64_t b = 0; b < kvBlockSize; b += packb_size) {
cpublas::brgemm(
qBlockSize,
packb_size,
eheadSize,
headSize_even ? qStrideM : eheadSize,
packb_size,
rkvBlockSize,
1.f,
0.f,
!headSize_even
? query_t_padding_ptr
: q_data + i * qStrideB + j * qStrideH + m * qStrideM,
key_reorder_ptr + i * num_head * eheadSize * rkvSize +
j * eheadSize * rkvSize + n * eheadSize + b * eheadSize,
qk_data + b);
}
}
} else {
cpublas::gemm(
TransposeType::Transpose,
TransposeType::NoTranspose,
kvBlockSize,
qBlockSize,
headSize,
static_cast<accum_t>(1),
k_data + i * kStrideB + j * kStrideH +
n * kStrideN,
kStrideN,
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
qStrideM,
static_cast<accum_t>(0),
qk_data,
kvBlockSize);
}
// Apply causal mask, fill unused with -inf
// Only the last visited kv block can contain masked-out columns.
if (is_causal && num_keys - n <= kvSplitSize) {
for (const auto row : c10::irange(qBlockSize)) {
// Last column (relative to this kv block) that the row may attend to.
int64_t last_col = m + row - n;
accum_t* row_ptr = qk_data + row * rkvBlockSize;
fill_stub(row_ptr + last_col + 1,
-std::numeric_limits<accum_t>::infinity(),
kvBlockSize - last_col - 1);
}
}
// Update attention weights with attention mask
// And apply scaling factor
// qk <- qk * scaling + attn_mask
if (has_attn_mask) {
for (int64_t row = 0; row < qBlockSize; ++row) {
#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
_scale_attn_mask_fusion_kernel(
qk_data + row * rkvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
kvBlockSize,
qk_data + row * rkvBlockSize,
scaling_factor,
mStrideN == 0);
#else
if (mStrideN == 0) {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
qk_data + row * rkvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM,
kvBlockSize,
qk_data + row * rkvBlockSize,
scaling_factor);
} else {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
qk_data + row * rkvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
qk_data + row * rkvBlockSize,
scaling_factor);
}
#endif
}
}
// Update coefficients with Softmax
accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
for (int64_t row = 0; row < qBlockSize; ++row) {
if (has_attn_mask) {
// max per row
tmp_max = at::vec::reduce_all<accum_t>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
qk_data + row * rkvBlockSize,
kvBlockSize);
} else {
// apply scaling factor and max per row in fusion
_mul_reduce_max_fusion_kernel(
qk_data + row * rkvBlockSize,
scaling_factor,
kvBlockSize,
qk_data + row * rkvBlockSize,
tmp_max);
}
// Combine with the running max from the previous kv blocks.
tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
// to avoid `nan = exp2f(-inf - (-inf))`
fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
static_cast<scalar_t>(0), kvBlockSize);
} else {
// tmp_sum carries the max into the kernel and the row sum out of it.
tmp_sum = tmp_max;
// qk <- exp(qk - max) and sum per row
_exp_reduce_sum_fusion_kernel(
qk_data + row * rkvBlockSize, kvBlockSize,
conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
tmp_sum);
// exp_tmp <- exp(max[row] - max)
exp_tmp = std::exp(qk_max_data[row] - tmp_max);
// sum[row] <- sum + exp_tmp * sum[row]
qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
// max[row] <- max
qk_max_data[row] = tmp_max;
// dst <- dst * exp_tmp
// (skipped for the first kv block, where dst has not been written yet)
if (n > 0) {
vec::map<accum_t>(
[exp_tmp](Vec x) { return x * Vec(exp_tmp); },
dst_data + row * rHeadSize,
dst_data + row * rHeadSize,
headSize);
}
}
if (need_pack && kvBlockSize % 2 != 0) {
// Pad: [qSplitSize,kvSplitSize] -> [qSplitSize,kvSplitSize + 1]
*(qk_reduced_data + row * (1 + kvBlockSize) + kvBlockSize) = scalar_t(0);
}
}
// Calculate Softmax(q @ k.T) @ v
if (need_pack) {
// Row offset of this kv block inside the padded value_reorder buffer.
int64_t psize = n / kvSplitSize * ekvSplitSize;
if constexpr (std::is_same_v<scalar_t, at::Half>) {
for (int64_t b = 0; b < headSize; b += packb_size) {
cpublas::brgemm(
qBlockSize,
packb_size,
ekvBlockSize,
ekvBlockSize,
packb_size,
rHeadSize,
1.0,
n == 0 ? 0.f : 1.f,
qk_reduced_data,
value_reorder_ptr +
i * num_head * kv_padding_size * rHeadSize +
j * kv_padding_size * rHeadSize + psize * rHeadSize +
b * ekvBlockSize,
dst_data + b);
}
}
} else {
// beta is 0 for the first kv block (overwrite dst) and 1 afterwards
// (accumulate into dst).
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::NoTranspose,
headSize,
qBlockSize,
kvBlockSize,
static_cast<accum_t>(1),
v_data + i * vStrideB + j * vStrideH +
n * vStrideN,
vStrideN,
conditional_data_ptr(qk_data, qk_reduced_data),
kvBlockSize,
n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
dst_data,
headSize);
}
}
// dst <- dst / sum[row]
// reorder MHA output with strides
for (int64_t row = 0; row < qBlockSize; ++row) {
// Row sums for full masked out rows are 0, we set them to 1
// in order to avoid NaNs in the output and instead set fully
// masked out rows to 0
qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
accum_t sum_reciprocal = 1 / qk_sum_data[row];
vec::map<scalar_t>(
[sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
dst_data + row * rHeadSize,
headSize);
}
// Store logsumexp for backward
accum_t* lse_ptr = lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
for (const auto row : c10::irange(qBlockSize)) {
lse_ptr[row * lStrideM] = qk_max_data[row]
+ std::log(qk_sum_data[row]);
}
// Move to the next query
data_index_step(i, batchSize, j, num_head, k, qSlice);
}
});
// Release brgemm resources once the packed path is done.
if (need_pack) {
cpublas::brgemm_release();
}
}
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention_backward(
const at::Tensor& grad_q,
const at::Tensor& grad_k,
const at::Tensor& grad_v,
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const at::Tensor& logsumexp,
double dropout_p,
bool is_causal,
std::optional<Tensor> attn_mask,
std::optional<double> scale) {
constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
using accum_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<accum_t>;
accum_t scaling_factor =
sdp::calculate_scale(query, scale).as_float_unchecked();
// Sizes
TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
"scaled_dot_product_attention_flash_attention_backward: Q/K/V should have the same head size");
// Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
// Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
int64_t batchSize = query.size(0);
int64_t qSize = query.size(1);
int64_t kvSize = value.size(1);
int64_t num_head = query.size(2);
int64_t headSize = query.size(3);
bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
if (has_attn_mask) {
reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
}
// Strides
int64_t qStrideB = query.stride(0);
int64_t qStrideM = query.stride(1);
int64_t qStrideH = query.stride(2);
int64_t kStrideB = key.stride(0);
int64_t kStrideN = key.stride(1);
int64_t kStrideH = key.stride(2);
int64_t vStrideB = value.stride(0);
int64_t vStrideN = value.stride(1);
int64_t vStrideH = value.stride(2);
int64_t oStrideB = out.stride(0);
int64_t oStrideM = out.stride(1);
int64_t oStrideH = out.stride(2);
int64_t lStrideB = logsumexp.stride(0);
int64_t lStrideM = logsumexp.stride(1);
int64_t lStrideH = logsumexp.stride(2);
int64_t mStrideB =
(has_attn_mask && attn_mask.value().size(0) > 1)
? attn_mask.value().stride(0)
: 0;
int64_t mStrideH =
(has_attn_mask && attn_mask.value().size(1) > 1)
? attn_mask.value().stride(1)
: 0;
int64_t mStrideM =
(has_attn_mask && attn_mask.value().size(2) > 1)
? attn_mask.value().stride(2)
: 0;
int64_t mStrideN =
(has_attn_mask && attn_mask.value().size(3) > 1)
? attn_mask.value().stride(3)
: 0;
int64_t grad_qStrideB = grad_q.stride(0);
int64_t grad_qStrideM = grad_q.stride(1);
int64_t grad_qStrideH = grad_q.stride(2);
int64_t grad_kStrideB = grad_k.stride(0);
int64_t grad_kStrideN = grad_k.stride(1);
int64_t grad_kStrideH = grad_k.stride(2);
int64_t grad_vStrideB = grad_v.stride(0);
int64_t grad_vStrideN = grad_v.stride(1);
int64_t grad_vStrideH = grad_v.stride(2);
int64_t grad_oStrideB = grad_out.stride(0);
int64_t grad_oStrideM = grad_out.stride(1);
int64_t grad_oStrideH = grad_out.stride(2);
int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
int64_t num_thread = at::get_num_threads();
const auto dtype = query.scalar_type();
const auto accumulate_dtype = toOpMathType(dtype);
// allocate per thread temp buf (accumulate type)
int64_t size_per_thread =
/* attn */ qSplitSize * kvSplitSize +
/* grad_attn */ qSplitSize * kvSplitSize;
at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
// allocate per thread temp buf_reduced (scalar type)
// buf2 is only needed for bfloat16 and float16
int64_t size_per_thread_reduced =
/* attn_reduced */ qSplitSize * kvSplitSize +
/* grad_attn_reduced */ qSplitSize * kvSplitSize;
at::Tensor buf_reduced = at::empty({num_thread, is_reduced_type ? size_per_thread_reduced : 0}, query.options());
scalar_t* grad_q_data = grad_q.data_ptr<scalar_t>();
scalar_t* grad_k_data = grad_k.data_ptr<scalar_t>();
scalar_t* grad_v_data = grad_v.data_ptr<scalar_t>();
const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
const scalar_t* k_data = key.const_data_ptr<scalar_t>();
const scalar_t* v_data = value.const_data_ptr<scalar_t>();
mask_t* mask_data = has_attn_mask
? attn_mask.value().data_ptr<mask_t>()
: nullptr;
const scalar_t* out_data = out.const_data_ptr<scalar_t>();
const accum_t* lse_data = logsumexp.const_data_ptr<accum_t>();
accum_t* buf_data = buf.data_ptr<accum_t>();
scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;
at::parallel_for(0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0;
data_index_init(begin, i, batchSize, j, num_head);
int ompIdx = at::get_thread_num();
accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
accum_t* attn_data = buf_ptr;
accum_t* grad_attn_data = attn_data + qSplitSize * kvSplitSize;
scalar_t* buf_reduced_ptr = is_reduced_type ? buf_reduced_data + ompIdx * size_per_thread_reduced : nullptr;
scalar_t* attn_reduced_data = is_reduced_type ? buf_reduced_ptr : nullptr;
scalar_t* grad_attn_reduced_data = is_reduced_type ? attn_reduced_data + qSplitSize * kvSplitSize : nullptr;
at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype));
accum_t* dsum_data = dsum.data_ptr<accum_t>();
for (const auto z : c10::irange(begin, end)) {
(void)z; // Suppress unused variable
// rowsum of grad_out * out
for (int64_t m = 0; m < qSize; m += qSplitSize) {
int64_t qBlockSize = std::min(qSplitSize, qSize - m);
// dsum <- rowsum(grad_out * out)
for (const auto row : c10::irange(qBlockSize)) {
*(dsum_data + row) = vec::map2_reduce_all<scalar_t>(
[](Vec x, Vec y) { return x * y; },
[](Vec x, Vec y) { return x + y; },
grad_out_data + i * grad_oStrideB + j * grad_oStrideH + (m + row) * grad_oStrideM,
out_data + i * oStrideB + j * oStrideH + (m + row) * oStrideM,
headSize);
}
int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
// attn <- scale * q @ k.T
cpublas::gemm(
TransposeType::Transpose,
TransposeType::NoTranspose,
kvBlockSize,
qBlockSize,
headSize,
scaling_factor,
k_data + i * kStrideB + j * kStrideH +
n * kStrideN,
kStrideN,
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
qStrideM,
static_cast<accum_t>(0),
attn_data,
kvBlockSize);
// attn <- attn + mask
if (has_attn_mask) {
accum_t one = accum_t(1);
for (const auto row : c10::irange(qBlockSize)) {
#if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE)
_scale_attn_mask_fusion_kernel(
attn_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
kvBlockSize,
attn_data + row * kvBlockSize,
one,
mStrideN == 0);
#else
if (mStrideN == 0) {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
attn_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM,
kvBlockSize,
attn_data + row * kvBlockSize,
one);
} else {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
attn_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
attn_data + row * kvBlockSize,
one);
}
#endif
}
}
// restore self attention after softmax from logsumexp
// attn <- exp(attn - normalizer)
for (const auto row : c10::irange(qBlockSize)) {
accum_t normalizer = lse_data[i * lStrideB + j * lStrideH + (m + row) * lStrideM];
vec::map<accum_t>(
[normalizer](Vec x) { return (x - Vec(normalizer)).exp(); },
attn_data + row * kvBlockSize,
attn_data + row * kvBlockSize,
kvBlockSize);
}
// Apply causal mask, filled unused with 0
if (is_causal && num_keys - n <= kvSplitSize) {
for (const auto row : c10::irange(qBlockSize)) {
int64_t last_col = m + row - n;
accum_t* row_ptr = attn_data + row * kvBlockSize;
fill_stub(row_ptr + last_col + 1, static_cast<accum_t>(0), kvBlockSize - last_col - 1);
}
}
#ifdef _MSC_VER
if (is_reduced_type) {
#else
if constexpr (is_reduced_type) {
#endif
for (const auto row : c10::irange(qBlockSize)) {
convert<accum_t, scalar_t>(
attn_data + row * kvBlockSize,
attn_reduced_data + row * kvBlockSize,
kvBlockSize);
}
}
// grad_v <- grad_v + attn.T @ grad_out
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::Transpose,
headSize,
kvBlockSize,
qBlockSize,
static_cast<accum_t>(1),
grad_out_data + i * grad_oStrideB + j * grad_oStrideH +
m * grad_oStrideM,
grad_oStrideM,
conditional_data_ptr(attn_data, attn_reduced_data),
kvBlockSize,
static_cast<accum_t>(1),
grad_v_data + i * grad_vStrideB + j * grad_vStrideH +
n * grad_vStrideN,
grad_vStrideN);
// grad_attn <- grad_out @ v.T
cpublas::gemm(
TransposeType::Transpose,
TransposeType::NoTranspose,
kvBlockSize,
qBlockSize,
headSize,
static_cast<accum_t>(1),
v_data + i * vStrideB + j * vStrideH +
n * vStrideN,
vStrideN,
grad_out_data + i * grad_oStrideB + j * grad_oStrideH +
m * grad_oStrideM,
grad_oStrideM,
static_cast<accum_t>(0),
grad_attn_data,
kvBlockSize);
// grad_attn <- attn * (grad_attn - dsum)
for (const auto row : c10::irange(qBlockSize)) {
accum_t d = *(dsum_data + row);
vec::map2<accum_t>(
[d](Vec attn, Vec grad_attn) { return attn * (grad_attn - Vec(d)); },
grad_attn_data + row * kvBlockSize,
attn_data + row * kvBlockSize,
grad_attn_data + row * kvBlockSize,
kvBlockSize);
}
#ifdef _MSC_VER
if (is_reduced_type) {
#else
if constexpr (is_reduced_type) {
#endif
for (const auto row : c10::irange(qBlockSize)) {
convert<accum_t, scalar_t>(
grad_attn_data + row * kvBlockSize,
grad_attn_reduced_data + row * kvBlockSize,
kvBlockSize);
}
}
// grad_q <- grad_q + scale * grad_attn @ k
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::NoTranspose,
headSize,
qBlockSize,
kvBlockSize,
scaling_factor,
k_data + i * kStrideB + j * kStrideH +
n * kStrideN,
kStrideN,
conditional_data_ptr(grad_attn_data, grad_attn_reduced_data),
kvBlockSize,
static_cast<accum_t>(1),
grad_q_data + i * grad_qStrideB + j * grad_qStrideH +
m * grad_qStrideM,
grad_qStrideM);
// grad_k <- grad_k + scale * grad_attn.T @ q
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::Transpose,
headSize,
kvBlockSize,
qBlockSize,
scaling_factor,
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
qStrideM,
conditional_data_ptr(grad_attn_data, grad_attn_reduced_data),
kvBlockSize,
static_cast<accum_t>(1),
grad_k_data + i * grad_kStrideB + j * grad_kStrideH +
n * grad_kStrideN,
grad_kStrideN);
}
}
// Move to the next query
data_index_step(i, batchSize, j, num_head);
}
});
}
// Dispatch helper for the attention-mask dtype.
//
// Switches on TYPE (a scalar type) and runs the callable passed in
// __VA_ARGS__ for one of the supported mask dtypes: Bool, Float, Double,
// BFloat16 or Half. Inside that callable the selected C++ type is visible
// under the alias `mask_t`.
// NAME is forwarded unchanged to AT_DISPATCH_SWITCH (presumably used as the
// operator name in dispatch error messages -- confirm against Dispatch.h).
#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Bool, mask_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Float, mask_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Double, mask_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::BFloat16, mask_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Half, mask_t, __VA_ARGS__))
// Invokes FNAME<TYPE1, TYPE2, SEQ1, SEQ2[, true]>(__VA_ARGS__).
//
// When PACK evaluates to true at run time, the instantiation with a trailing
// `true` template argument is called; otherwise the four-argument
// instantiation is called and FNAME's own default for the fifth template
// parameter applies (presumably the "use packed GEMM" flag -- confirm against
// the FNAME declaration).
//
// The body is wrapped in do { ... } while (0) so that an invocation followed
// by `;` is exactly one statement. That keeps the macro safe inside unbraced
// if/else chains and prevents a caller's `else` from binding to the `if`
// hidden in the expansion.
#define FLASH_ATTENTION_KERNEL(FNAME, PACK, TYPE1, TYPE2, SEQ1, SEQ2, ...) \
  do {                                                                     \
    if (PACK) {                                                            \
      FNAME<TYPE1, TYPE2, SEQ1, SEQ2, true>(__VA_ARGS__);                  \
    } else {                                                               \
      FNAME<TYPE1, TYPE2, SEQ1, SEQ2>(__VA_ARGS__);                        \
    }                                                                      \
  } while (0)
// Entry point for the CPU flash-attention forward pass.
//
// Chooses which cpu_flash_attention instantiation to run based on:
//   * the dtype of `query` (floating types plus BFloat16 and Half),
//   * the dtype of the attention mask, when one is supplied
//     (see AT_DISPATCH_MASK_TYPES),
//   * `query.size(2)`, used here as the query sequence length:
//       >= 768 -> 256, >= 192 -> 64, otherwise 32 for the first integer
//       template argument; the second is always 512 (presumably the
//       query / key-value split sizes -- confirm against cpu_flash_attention).
// The packed variant is requested only for Half inputs for which
// cpublas::need_pack(kHalf) returns true.
//
// All arguments are forwarded unchanged to cpu_flash_attention, except that
// an engaged-but-undefined `attn_mask` is treated as "no mask", matching
// flash_attention_backward_kernel_impl.
void flash_attention_kernel_impl(
    const Tensor& output,
    const Tensor& logsumexp,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  auto q_seq_len = query.size(2);
  // When q_seq_len and k_seq_len are long enough,
  // cpu_flash_attention with pack has better performance.
  bool could_pack = (query.scalar_type() == kHalf && cpublas::need_pack(kHalf));
  // An optional that holds an undefined tensor carries no mask. Dispatching
  // on its scalar type would fail because it is not one of the supported mask
  // dtypes, so collapse it to nullopt and let the mask-free path handle it.
  // `attn_mask` is a by-value parameter, so the caller's object is untouched.
  if (attn_mask.has_value() && !attn_mask.value().defined()) {
    attn_mask.reset();
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention", [&] {
    if (!attn_mask.has_value()) {
      // No mask: the mask type parameter is simply set to scalar_t.
      if (q_seq_len >= 768) {
        FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 256, 512,
          output, logsumexp, query, key, value,
          dropout_p, is_causal, attn_mask, scale);
      } else if (q_seq_len >= 192) {
        FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 64, 512,
          output, logsumexp, query, key, value,
          dropout_p, is_causal, attn_mask, scale);
      } else {
        FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 32, 512,
          output, logsumexp, query, key, value,
          dropout_p, is_causal, attn_mask, scale);
      }
    } else {
      // Mask present: additionally dispatch on the mask dtype (mask_t).
      AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask", [&]() {
        if (q_seq_len >= 768) {
          FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 256, 512,
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
        } else if (q_seq_len >= 192) {
          FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 64, 512,
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
        } else {
          FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 32, 512,
            output, logsumexp, query, key, value,
            dropout_p, is_causal, attn_mask, scale);
        }
      });
    }
  });
}
#undef FLASH_ATTENTION_KERNEL
// Entry point for the CPU flash-attention backward pass.
//
// Selects a cpu_flash_attention_backward instantiation from the dtype of
// `query`, the dtype of the attention mask (when a defined mask is supplied)
// and `query.size(1)`, which is used here as the query sequence length:
//   < 192 -> 32, < 768 -> 64, otherwise 256 for the first integer template
//   argument; the second is always 512.
// When there is no usable mask, the mask type parameter is set to the
// opmath (accumulation) type of scalar_t.
void flash_attention_backward_kernel_impl(
    const at::Tensor& grad_q,
    const at::Tensor& grad_k,
    const at::Tensor& grad_v,
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  // grad_out is fed to gemm by the kernel. A zero stride in the leading
  // dimension (i.e. a broadcasted dimension) would send gemm down a slow
  // path, so work on a contiguous version instead.
  auto contiguous_grad_out = grad_out.contiguous();
  auto query_len = query.size(1);
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention_backward", [&] {
    const bool mask_is_usable =
        attn_mask.has_value() && attn_mask.value().defined();
    if (mask_is_usable) {
      // A defined mask was supplied: dispatch on its dtype as well.
      AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask_backward", [&]() {
        if (query_len < 192) {
          cpu_flash_attention_backward<scalar_t, mask_t, 32, 512>(
              grad_q, grad_k, grad_v, contiguous_grad_out,
              query, key, value, out, logsumexp,
              dropout_p, is_causal, attn_mask, scale);
        } else if (query_len < 768) {
          cpu_flash_attention_backward<scalar_t, mask_t, 64, 512>(
              grad_q, grad_k, grad_v, contiguous_grad_out,
              query, key, value, out, logsumexp,
              dropout_p, is_causal, attn_mask, scale);
        } else {
          cpu_flash_attention_backward<scalar_t, mask_t, 256, 512>(
              grad_q, grad_k, grad_v, contiguous_grad_out,
              query, key, value, out, logsumexp,
              dropout_p, is_causal, attn_mask, scale);
        }
      });
    } else {
      // Mask absent or undefined: use the accumulation type as the mask type.
      using accum_t = at::opmath_type<scalar_t>;
      if (query_len < 192) {
        cpu_flash_attention_backward<scalar_t, accum_t, 32, 512>(
            grad_q, grad_k, grad_v, contiguous_grad_out,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
      } else if (query_len < 768) {
        cpu_flash_attention_backward<scalar_t, accum_t, 64, 512>(
            grad_q, grad_k, grad_v, contiguous_grad_out,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
      } else {
        cpu_flash_attention_backward<scalar_t, accum_t, 256, 512>(
            grad_q, grad_k, grad_v, contiguous_grad_out,
            query, key, value, out, logsumexp,
            dropout_p, is_causal, attn_mask, scale);
      }
    }
  });
}
} // anonymous namespace
// Hook the two entry points above into their dispatch stubs
// (flash_attention_kernel / flash_attention_backward_kernel, declared outside
// this chunk). NOTE(review): the ALSO_REGISTER_AVX512 variant presumably also
// registers the AVX512 build of this translation unit -- confirm against
// DispatchStub.h.
ALSO_REGISTER_AVX512_DISPATCH(flash_attention_kernel, &flash_attention_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(flash_attention_backward_kernel, &flash_attention_backward_kernel_impl);
} // at::native