#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/zmath.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>
#include <array>
#include <memory>
#include <type_traits>
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>
namespace at::native::blas_impl {
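// fp16/bf16 GEMV and dot-product kernels for aarch64; they are defined in a
// separate translation unit, so only the declarations are needed here.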
void fp16_gemv_notrans(
const int m,
const int n,
const float alpha,
const float16_t* a,
const int lda,
const float16_t* x,
const int incx,
const float beta,
float16_t* y,
const int incy);
void fp16_gemv_trans(
const int m,
const int n,
const float alpha,
const float16_t* a,
const int lda,
const float16_t* x,
const int incx,
const float beta,
float16_t* y,
const int incy);
float fp16_dot_with_fp32_arith(
const float16_t* x,
const float16_t* a,
int64_t len);
float bf16_dot_with_fp32_arith(
const at::BFloat16* x,
const at::BFloat16* a,
int64_t len);
}
#endif
namespace at::native {
namespace cpublas {
namespace {
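// Scales the m-by-n column-major matrix `a` (leading dimension lda) by alpha
// in place. alpha == 1 is a no-op; alpha == 0 stores zeros explicitly rather
// than multiplying, so NaN/Inf entries in `a` are not propagated (BLAS-style
// beta semantics; this is how the gemm kernels below apply beta to c).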
template <typename scalar_t, typename opmath_t>
void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t *a, int64_t lda) {
if (alpha == opmath_t(1)) {
return; // identity
}
if (alpha == opmath_t(0)) {
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
a[j * lda + i] = scalar_t(0);
}
}
return;
}
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
a[j * lda + i] *= alpha;
}
}
}
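// Computes sum_{i=0}^{N-1} f(i) using ilp_factor independent accumulators.
// Splitting the reduction removes the serial dependency on a single
// accumulator so the additions can overlap in flight; the tail elements and
// the final combine are handled after the unrolled loop.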
template <typename Func>
auto sum(int64_t N, Func f) {
constexpr int ilp_factor = 4;
using acc_t = decltype(f(0));
// Calculate independent partial sums, then add them together at the end
std::array<acc_t, ilp_factor> partial_sums{};
int64_t i = 0;
for (; i + ilp_factor <= N; i += ilp_factor) {
c10::ForcedUnroll<ilp_factor>{}([&](int k) {
partial_sums[k] += f(i + k);
});
}
for (; i < N; ++i) {
partial_sums[0] += f(i);
}
for (int k = 1; k < ilp_factor; ++k) {
partial_sums[0] += partial_sums[k];
}
return partial_sums[0];
}
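// Same-precision kernel (scalar_t == opmath_t): apply beta via scale_, then
// for each (l, j) broadcast alpha * b[l, j] and FMA it against column l of a
// into column j of c. The inner loop over rows walks a and c contiguously
// (column-major) and is manually unrolled by 4.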
template <typename scalar_t, typename opmath_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
gemm_notrans_(
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
int64_t ldc) {
// c *= beta
scale_(m, n, beta, c, ldc);
// c += alpha * (a @ b)
for (const auto l : c10::irange(k)) {
for (const auto j : c10::irange(n)) {
opmath_t val = b[l + j * ldb] * alpha;
int64_t i_m = m / 4;
for (const auto i_i : c10::irange(i_m)) {
c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
}
int64_t i = i_m * 4;
for (; i < m; i++)
c[j * ldc + i] += a[i + l * lda] * val;
}
}
}
// Overload for reduced-precision scalar types (scalar_t is at::BFloat16 or
// at::Half, so scalar_t != opmath_t); dot products are accumulated in the
// wider opmath_t.
template <typename scalar_t, typename opmath_t>
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
gemm_notrans_(
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
int64_t ldc) {
// c = beta * c + alpha * (a @ b)
for (const auto i : c10::irange(m)) {
for (const auto j : c10::irange(n)) {
const auto dot = sum(k, [&](int64_t l) -> opmath_t {
return static_cast<opmath_t>(a[l * lda + i]) *
static_cast<opmath_t>(b[j * ldb + l]);
});
if (beta == opmath_t(0)) {
c[j * ldc + i] = alpha * dot;
} else {
c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
}
}
}
}
template <typename scalar_t, typename opmath_t>
void gemm_transa_(
TransposeType transa,
int64_t m, int64_t n, int64_t k,
opmath_t alpha,
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
// c = alpha * (a.T @ b) + beta * c
const scalar_t *a_ = a;
for (const auto i : c10::irange(m)) {
const scalar_t *b_ = b;
for (const auto j : c10::irange(n)) {
const auto dot = sum(k, [&](int64_t l) -> opmath_t {
return static_cast<opmath_t>(transa == TransposeType::ConjTranspose ? conj_impl(a_[l]) : a_[l]) * static_cast<opmath_t>(b_[l]);
});
b_ += ldb;
if (beta == opmath_t(0)) {
c[j*ldc+i] = alpha*dot;
} else {
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
}
}
a_ += lda;
}
}
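// Shared inner loop for the a @ b.T case. It accumulates into an opmath_t
// buffer so that it can serve both the same-precision overload (where c is
// already opmath_t) and the reduced-precision overload below, which passes a
// temporary higher-precision buffer.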
template <typename scalar_t, typename opmath_t>
void gemm_transb_impl(
TransposeType transb,
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
/* beta must already have been applied to c by the caller */
opmath_t* c,
int64_t ldc) {
// c += alpha * (a @ b.T)
for (const auto l : c10::irange(k)) {
for (const auto j : c10::irange(n)) {
opmath_t val = (transb == TransposeType::ConjTranspose ? conj_impl(b[j + l * ldb]) : b[j + l * ldb]) * alpha;
int64_t i_m = m / 4;
for (const auto i_i : c10::irange(i_m)) {
c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
}
int64_t i = i_m * 4;
for (; i < m; i++)
c[j * ldc + i] += a[i + l * lda] * val;
}
}
}
template <typename scalar_t, typename opmath_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
gemm_transb_(
TransposeType transb,
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
int64_t ldc) {
// c *= beta
scale_(m, n, beta, c, ldc);
gemm_transb_impl(transb, m, n, k, alpha, a, lda, b, ldb, c, ldc);
}
// Overload for reduced-precision scalar types (scalar_t is at::BFloat16 or
// at::Half); see the comment below on why accumulation happens at higher
// precision.
template <typename scalar_t, typename opmath_t>
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
gemm_transb_(
TransposeType transb,
int64_t m,
int64_t n,
int64_t k,
opmath_t alpha,
const scalar_t* a,
int64_t lda,
const scalar_t* b,
int64_t ldb,
opmath_t beta,
scalar_t* c,
int64_t ldc) {
// We need to calculate full-precision dot products for correctness;
// users notice error accumulation with reduced-width types (e.g.,
// https://github.com/pytorch/pytorch/issues/95125 and
// https://github.com/pytorch/pytorch/issues/83863, which were filed
// when we used gemm_transb_impl naively, accumulating into
// float16/bfloat16). The straightforward fix would be to compute each
// output element as its own full-precision dot product, but that gives
// terrible performance because of the non-contiguous matrix access.
// Instead, we allocate temporary space that holds the output at higher
// precision, so we can keep using the cache-friendly gemm_transb_impl
// algorithm above ("load one element of b, FMA it against a contiguous
// column of a into the corresponding column of the result") and convert
// back to scalar_t only once at the end.
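// Illustrative example (not from the issues above): once a running float16
// sum exceeds 256, the spacing between adjacent representable float16 values
// is 0.25, so repeatedly adding terms of magnitude around 0.1 loses most of
// each addend. Accumulating in float32 and rounding once at the end avoids
// this kind of drift.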
const auto c_size = m * n;
auto c_accum = std::make_unique<opmath_t[]>(c_size);
if (beta == 1) {
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
c_accum[j * m + i] = c[j * ldc + i];
}
}
} else if (beta == 0) {
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
c_accum[j * m + i] = 0;
}
}
} else {
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
c_accum[j * m + i] = beta * c[j * ldc + i];
}
}
}
gemm_transb_impl(transb, m, n, k, alpha, a, lda, b, ldb, c_accum.get(), m);
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
c[j * ldc + i] = c_accum[j * m + i];
}
}
}
template <typename scalar_t, typename opmath_t>
void gemm_transab_(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
opmath_t alpha,
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
// c = beta * c + alpha * (a.T @ b.T)
for (const auto i : c10::irange(m)) {
for (const auto j : c10::irange(n)) {
const auto dot = sum(k, [&](int64_t l) -> opmath_t {
return static_cast<opmath_t>(transa == TransposeType::ConjTranspose ? conj_impl(a[i * lda + l]) : a[i * lda + l]) *
static_cast<opmath_t>(transb == TransposeType::ConjTranspose ? conj_impl(b[l * ldb + j]) : b[l * ldb + j]);
});
if (beta == opmath_t(0)) {
c[j * ldc + i] = alpha * dot;
} else {
c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
}
}
}
}
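// On non-mobile aarch64 builds, Half and BFloat16 get specializations that
// route through the fp16/bf16 helpers declared at the top of this file
// (under the same preprocessor guard) and compute dot products in float32.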
#if defined(__aarch64__) && !defined(C10_MOBILE)
template <>
void gemm_notrans_(
int64_t m,
int64_t n,
int64_t k,
float alpha,
const at::Half* a,
int64_t lda,
const at::Half* b,
int64_t ldb,
float beta,
at::Half* c,
int64_t ldc) {
// c = beta * c + alpha * (a @ b)
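// With a single output column and beta == 0, the product collapses to a
// matrix-vector product y = alpha * A * x, so hand it to the fp16 GEMV
// kernel.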
if (n == 1 && beta == 0.0) {
at::native::blas_impl::fp16_gemv_notrans(m, k, alpha, reinterpret_cast<const float16_t*>(a), lda, reinterpret_cast<const float16_t*>(b), 1, beta, reinterpret_cast<float16_t*>(c), 1);
return;
}
for (const auto i : c10::irange(m)) {
for (const auto j : c10::irange(n)) {
const auto dot = sum(k, [&](int64_t l) -> float {
return float(c10::detail::fp16_from_bits(a[l * lda + i].x)) *
float(c10::detail::fp16_from_bits(b[j * ldb + l].x));
});
if (beta == 0) {
c[j * ldc + i] = alpha * dot;
} else {
c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
}
}
}
}
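// Widens four bfloat16 values to 32-bit lanes and shifts them into the high
// half of each lane; since bfloat16 is the upper 16 bits of an IEEE float32,
// reinterpreting the result yields the exact float32 values.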
inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
int32x4_t shift = vdupq_n_s32(16);
uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}
static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) {
return at::native::blas_impl::fp16_dot_with_fp32_arith(
reinterpret_cast<const float16_t*>(a),
reinterpret_cast<const float16_t*>(b),
len);
}
static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t len) {
return at::native::blas_impl::bf16_dot_with_fp32_arith(a, b, len);
}
template <>
void gemm_transa_(
TransposeType transa,
int64_t m, int64_t n, int64_t k,
float alpha,
const at::Half *a, int64_t lda,
const at::Half *b, int64_t ldb,
float beta,
at::Half *c, int64_t ldc) {
// c = alpha * (a.T @ b) + beta * c
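// With a single output column and beta == 0, a.T @ b is a transposed
// matrix-vector product, so hand it to the fp16 transposed-GEMV kernel
// (passing the dimensions of a as stored). Conjugation is a no-op for real
// types, so the Conj flag in transa can be ignored here.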
if (n == 1 && beta == 0.0) {
at::native::blas_impl::fp16_gemv_trans(k, m, alpha, reinterpret_cast<const float16_t*>(a), lda, reinterpret_cast<const float16_t*>(b), 1, beta, reinterpret_cast<float16_t*>(c), 1);
return;
}
parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
const auto *a_ = a + begin * lda;
for (const auto i : c10::irange(begin, end)) {
const auto *b_ = b;
for (const auto j : c10::irange(n)) {
const auto dot = compute_dot(a_, b_, k);
b_ += ldb;
if (beta == 0) {
c[j*ldc+i] = alpha*dot;
} else {
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
}
}
a_ += lda;
}
});
}
template <>
void gemm_transa_(
TransposeType transa,
int64_t m, int64_t n, int64_t k,
float alpha,
const at::BFloat16 *a, int64_t lda,
const at::BFloat16 *b, int64_t ldb,
float beta,
at::BFloat16 *c, int64_t ldc) {
// c = alpha * (a.T @ b) + beta * c
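// Rows of the output (stored as contiguous columns of a) are distributed
// across threads with parallel_for; each dot product is evaluated in float32
// by bf16_dot_with_fp32_arith.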
parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
const auto *a_ = a + begin * lda;
for (const auto i : c10::irange(begin, end)) {
const auto *b_ = b;
for (const auto j : c10::irange(n)) {
const auto dot = compute_dot(a_, b_, k);
b_ += ldb;
if (beta == 0) {
c[j*ldc+i] = alpha*dot;
} else {
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
}
}
a_ += lda;
}
});
}
#endif
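// Generic gemm entry point: all operands are column-major with leading
// dimensions lda/ldb/ldc, and this routine only selects the kernel matching
// the (transa, transb) combination.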
template <typename scalar_t, typename opmath_t>
void gemm_core_(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
opmath_t alpha,
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
opmath_t beta,
scalar_t *c, int64_t ldc) {
if (transa == TransposeType::NoTranspose &&
transb == TransposeType::NoTranspose) {
return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
} else if (
transa != TransposeType::NoTranspose &&
transb == TransposeType::NoTranspose) {
gemm_transa_(transa, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
} else if (
transa == TransposeType::NoTranspose &&
transb != TransposeType::NoTranspose) {
gemm_transb_(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
} else {
gemm_transab_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
kHalf, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
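// gemm_stub kernel: dispatch on the runtime dtype, convert alpha/beta to the
// matching accumulation type (at::opmath_type), and run gemm_core_ on the
// raw pointers.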
void cpublas_gemm_impl(
at::ScalarType type,
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const Scalar& alpha,
const void *a, int64_t lda,
const void *b, int64_t ldb,
const Scalar& beta,
void *c, int64_t ldc) {
_AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
using opmath_t = at::opmath_type<scalar_t>;
gemm_core_(
transa, transb, m, n, k,
alpha.to<opmath_t>(),
static_cast<const scalar_t *>(a), lda,
static_cast<const scalar_t *>(b), ldb,
beta.to<opmath_t>(),
static_cast<scalar_t *>(c), ldc);
});
}
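// axpy (y += a * x) with strides incx/incy. Bool gets a separate branch where
// the update degenerates to y[i] |= a & x[i]; other dtypes convert the scalar
// through at::opmath_type before the multiply.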
void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
if (type == at::kBool) {
auto a = _a.to<bool>();
auto x = static_cast<const bool *>(_x);
auto y = static_cast<bool *>(_y);
for (int64_t i = 0; i < n; i++)
y[i*incy] |= a & x[i*incx];
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16, type, "cpublas_axpy_impl",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
auto a = _a.to<opmath_t>();
auto x = static_cast<const scalar_t *>(_x);
auto y = static_cast<scalar_t *>(_y);
for (int64_t i = 0; i < n; i++)
y[i*incy] += a*x[i*incx];
});
}
}
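// Plain strided copy; the dispatch here also covers kComplexHalf and kBool,
// which the gemm path above does not handle.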
void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t incx, void *_y, int64_t incy){
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::kComplexHalf, at::kHalf, at::kBFloat16, at::kBool, type, "cpublas_copy_impl",
[&] {
auto x = static_cast<const scalar_t *>(_x);
auto y = static_cast<scalar_t *>(_y);
for (int64_t i = 0; i < n; i++)
y[i*incy] = x[i*incx];
});
}
}} // namespace cpublas::(anonymous)
REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl);
REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl);
REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl);
} // namespace at::native