/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/kernels/optimized/blas/BlasKernel.h>
#ifdef __aarch64__
#include <arm_neon.h>
#endif
using torch::executor::BFloat16;
namespace executorch {
namespace cpublas {
namespace internal {
#ifdef __aarch64__
// Elementwise a + b * c, using a fused multiply-add when the target
// supports it.
static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f32(a, b, c);
#else
  return vaddq_f32(a, vmulq_f32(b, c));
#endif
}
// The reduce() overload and dot_with_fp32_arith() below are adapted
// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] above for the
// required notice.
// We need the shift amounts for reduce(), hence the extra *Shift constants.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1
    << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1
    << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration =
    kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(
    kF32RegistersPerIteration ==
    kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);
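// To make the relationships above concrete: each main-loop iteration consumes
// kF32ElementsPerIteration = 32 input elements, spread across
// kF32RegistersPerIteration = 8 float32x4_t accumulators of 4 lanes each,
// which the inner loop visits as kF32RegisterPairsPerIteration = 4 pairs of
// registers (8 elements per pair). The *Shift constants are the base-2 logs
// of those counts. A redundant sanity check, stated here only for clarity:
static_assert(
    kF32RegisterPairsPerIteration * 2 * kF32ElementsPerRegister ==
    kF32ElementsPerIteration);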
static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
  int offset = kF32RegistersPerIteration;
  utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
      [&offset, &x](auto idx) ET_INLINE_ATTRIBUTE {
        offset /= 2;
        for (int i = 0; i < offset; ++i) {
          x[i] = vaddq_f32(x[i], x[offset + i]);
        }
      });
  return vaddvq_f32(x[0]);
}
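// reduce() above is a log2-depth tree reduction: the 8 accumulators are
// folded pairwise (8 -> 4 -> 2 -> 1) by the ForcedUnroll, and the surviving
// register is horizontally summed with vaddvq_f32. It yields the same sum as
// the plain scalar loop below, up to floating-point reassociation; the sketch
// is illustrative only and not used by the kernel:
//
//   float reduce_scalar(const float v[kF32RegistersPerIteration][4]) {
//     float acc = 0.0f;
//     for (int r = 0; r < kF32RegistersPerIteration; ++r) {
//       for (int lane = 0; lane < 4; ++lane) {
//         acc += v[r][lane];
//       }
//     }
//     return acc;
//   }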
static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
  int32x4_t shift = vdupq_n_s32(16);
  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}

static ET_INLINE float32x4_t
f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}
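// Note on the two helpers above: a bfloat16's bit pattern is the upper 16
// bits of the corresponding float32, so (despite its name) to_bfloat16()
// widens four bf16 bit patterns to exact float32 values by shifting them into
// the high half of each 32-bit lane. Scalar equivalent, shown for
// illustration only (would need <cstring> for std::memcpy; not used by the
// kernel):
//
//   float bf16_bits_to_float(uint16_t bits) {
//     uint32_t widened = static_cast<uint32_t>(bits) << 16;
//     float result;
//     std::memcpy(&result, &widened, sizeof(result));
//     return result;
//   }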
#ifdef __ARM_FEATURE_BF16
static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  return vbfdotq_f32(a, b, c);
}
#endif
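// When the compiler targets the BF16 extension, the main loop below uses
// f32_dot_bf16(): each vbfdotq_f32 accumulates, into every float32 lane of
// the accumulator, the sum of products of the corresponding pair of adjacent
// bfloat16 elements from the two inputs, so one instruction consumes 8 bf16
// elements per operand. Without the extension, the fallback path widens the
// bf16 halves to float32 and uses f32_fma_bf16() instead.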
static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
#ifdef __ARM_FEATURE_BF16
  // Hardware BF16 path: one BFDOT per register pair of 8 elements.
  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
  sum[registerPairIndex] =
      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
#else
  // Fallback path: load 8 bf16 bit patterns, widen each half to float32,
  // and accumulate with FMA into two accumulators per register pair.
  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
  sum[2 * registerPairIndex] = f32_fma_bf16(
      sum[2 * registerPairIndex],
      vget_low_u16(temp_vec1),
      vget_low_u16(temp_vec2));
  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
      sum[2 * registerPairIndex + 1],
      vget_high_u16(temp_vec1),
      vget_high_u16(temp_vec2));
#endif
}
static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t* tailSum,
    int idx) {
  // Handle 4 elements at a time: widen bf16 to float32 and FMA into tailSum.
  const auto temp_vec1 =
      vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
  const auto temp_vec2 =
      vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}
template <typename T>
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
    utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
        [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
          dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
        });
  }
  auto reducedSum = reduce(sum);
  // First-tier tail fixup: make sure we handle workloads that can
  // benefit from vectorization, but don't fit into our fully unrolled
  // loop above.
  float32x4_t tailSum = vdupq_n_f32(0);
  const auto len_aligned_4 = len & ~3;
  for (int j = len_aligned; j < len_aligned_4; j += 4) {
    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
  }
  auto reducedTail = vpaddq_f32(tailSum, tailSum);
  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
  // Second-tier tail fixup: handle the remaining elements with scalar code.
  for (int j = len_aligned_4; j < len; ++j) {
    reducedSum += vec1[j] * vec2[j];
  }
  return reducedSum;
}
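// Usage note: dot_with_fp32_arith() computes sum_i float(vec1[i]) * float(vec2[i]),
// with products computed and accumulated in fp32 vector registers (the final
// reduction uses a wider scalar), never in bf16. A hypothetical caller-side
// sketch (buffer names and the length are illustrative, not part of this
// file; BFloat16 is torch::executor::BFloat16 as used above):
//
//   BFloat16 a[129], b[129];
//   // ... fill a and b ...
//   float dot = bf16_dot_with_fp32_arith(a, b, 129);
//   // 129 = 4 * 32 + 1: the unrolled main loop covers the first 128
//   // elements, the 4-wide tail loop is skipped (129 & ~3 == 128), and the
//   // final scalar loop handles the last element.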
float bf16_dot_with_fp32_arith(
    const BFloat16* vec1,
    const BFloat16* vec2,
    int64_t len) {
  return dot_with_fp32_arith(vec1, vec2, len);
}
#endif
} // namespace internal
} // namespace cpublas
} // namespace executorch