// blob: 554eb1989efde431fc4e5d52701b530379660870 [file] [log] [blame]
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/zmath.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>
namespace at::native {
namespace cpublas {
namespace {
// Scales an m-by-n block in place: a[j * lda + i] *= alpha for every
// 0 <= i < m and 0 <= j < n.
//
//   m, n   extents along the contiguous index i and the strided index j
//   alpha  scale factor
//   a      pointer to element (0, 0); modified in place
//   lda    distance in elements between consecutive values of j
//
// Two values of alpha are special-cased: alpha == 1 returns without touching
// memory, and alpha == 0 stores scalar_t(0) directly instead of multiplying,
// so the previous contents of the block are never read.
template <typename scalar_t, typename opmath_t>
void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t *a, int64_t lda) {
  if (alpha == opmath_t(1)) {
    // Multiplying by one changes nothing; skip the memory traffic.
    return;
  }

  const bool zero_fill = (alpha == opmath_t(0));
  if (zero_fill) {
    for (const auto col : c10::irange(n)) {
      const int64_t base = col * lda;
      for (const auto row : c10::irange(m)) {
        a[base + row] = scalar_t(0);
      }
    }
  } else {
    for (const auto col : c10::irange(n)) {
      const int64_t base = col * lda;
      for (const auto row : c10::irange(m)) {
        a[base + row] *= alpha;
      }
    }
  }
}
// Returns f(0) + f(1) + ... + f(N - 1), accumulated in the type f returns.
//
// The terms are distributed over `ilp_factor` independent accumulators: while
// at least ilp_factor terms remain, term (i + lane) is added to accumulator
// `lane`; any leftover terms are added to accumulator 0. The accumulators are
// then folded into accumulator 0 in increasing index order. For
// floating-point types this fixes a specific summation order, which is not
// the same as a plain left-to-right sum.
template <typename Func>
auto sum(int64_t N, Func f) {
  constexpr int ilp_factor = 4;
  using acc_t = decltype(f(0));

  // Value-initialised, so every lane starts from acc_t{}.
  std::array<acc_t, ilp_factor> lanes{};

  int64_t pos = 0;

  // Main part: consume ilp_factor terms per step, one per lane.
  while (pos + ilp_factor <= N) {
    c10::ForcedUnroll<ilp_factor>{}([&](int lane) {
      lanes[lane] += f(pos + lane);
    });
    pos += ilp_factor;
  }

  // Tail: fewer than ilp_factor terms are left; they all go to lane 0.
  while (pos < N) {
    lanes[0] += f(pos);
    ++pos;
  }

  // Fold the remaining lanes into lane 0.
  for (int lane = 1; lane < ilp_factor; ++lane) {
    lanes[0] += lanes[lane];
  }
  return lanes[0];
}
// Computes c = beta * c + alpha * (a @ b) with neither operand transposed.
// This overload is selected when scalar_t is the same type as opmath_t, so
// products are accumulated directly into c.
//
// Element layout used below (0 <= i < m, 0 <= j < n, 0 <= l < k):
//   a(i, l) = a[i + l * lda]
//   b(l, j) = b[l + j * ldb]
//   c(i, j) = c[i + j * ldc]
template <typename scalar_t, typename opmath_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b), built up one l at a time:
  //   c(:, j) += a(:, l) * (b(l, j) * alpha)
  constexpr int64_t unroll = 4;
  const int64_t m_main = (m / unroll) * unroll;

  for (const auto l : c10::irange(k)) {
    const int64_t a_off = l * lda;
    for (const auto j : c10::irange(n)) {
      const int64_t c_off = j * ldc;
      const opmath_t scaled_b = b[l + j * ldb] * alpha;

      // Manually unrolled by four over the contiguous index.
      int64_t i = 0;
      for (; i < m_main; i += unroll) {
        c[c_off + i + 0] += a[a_off + i + 0] * scaled_b;
        c[c_off + i + 1] += a[a_off + i + 1] * scaled_b;
        c[c_off + i + 2] += a[a_off + i + 2] * scaled_b;
        c[c_off + i + 3] += a[a_off + i + 3] * scaled_b;
      }
      // Remaining m % 4 elements.
      for (; i < m; ++i) {
        c[c_off + i] += a[a_off + i] * scaled_b;
      }
    }
  }
}
// Computes c = beta * c + alpha * (a @ b) with neither operand transposed.
// This overload is selected when scalar_t differs from opmath_t (e.g.
// at::BFloat16 or at::Half): each dot product is evaluated in opmath_t via
// sum() and converted to scalar_t only when it is stored into c.
//
// Element layout used below (0 <= i < m, 0 <= j < n, 0 <= l < k):
//   a(i, l) = a[l * lda + i]
//   b(l, j) = b[j * ldb + l]
//   c(i, j) = c[j * ldc + i]
//
// When beta == 0 the old value of c(i, j) is not read.
template <typename scalar_t, typename opmath_t>
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  const bool overwrite = (beta == opmath_t(0));

  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      // l-th term of the dot product of row i of a with column j of b.
      const auto term = [&](int64_t l) -> opmath_t {
        const auto lhs = static_cast<opmath_t>(a[l * lda + i]);
        const auto rhs = static_cast<opmath_t>(b[j * ldb + l]);
        return lhs * rhs;
      };
      const auto dot = sum(k, term);

      scalar_t& out = c[j * ldc + i];
      if (overwrite) {
        out = alpha * dot;
      } else {
        out = beta * out + alpha * dot;
      }
    }
  }
}
// Computes c = alpha * (op(a) @ b) + beta * c, where op(a) is a transposed;
// the elements of a are additionally passed through conj_impl when
// transa == TransposeType::ConjTranspose.
//
// Element layout used below (0 <= i < m, 0 <= j < n, 0 <= l < k):
//   op(a)(i, l) is read from a[i * lda + l]
//   b(l, j)     = b[j * ldb + l]
//   c(i, j)     = c[j * ldc + i]
//
// Each dot product is evaluated in opmath_t via sum(). When beta == 0 the old
// value of c(i, j) is not read.
template <typename scalar_t, typename opmath_t>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  const bool conj_a = (transa == TransposeType::ConjTranspose);
  const bool overwrite = (beta == opmath_t(0));

  for (const auto i : c10::irange(m)) {
    // Start of the k contiguous elements of a that feed output row i.
    const scalar_t *a_vec = a + i * lda;
    for (const auto j : c10::irange(n)) {
      // Start of the k contiguous elements of b that feed output column j.
      const scalar_t *b_vec = b + j * ldb;

      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(conj_a ? conj_impl(a_vec[l]) : a_vec[l]) *
            static_cast<opmath_t>(b_vec[l]);
      });

      scalar_t& out = c[j * ldc + i];
      if (overwrite) {
        out = alpha * dot;
      } else {
        out = beta * out + alpha * dot;
      }
    }
  }
}
// Computes c = beta * c + alpha * (a @ op(b)), where op(b) is b transposed;
// the elements of b are additionally passed through conj_impl when
// transb == TransposeType::ConjTranspose. This overload is selected when
// scalar_t is the same type as opmath_t, so products are accumulated directly
// into c.
//
// Element layout used below (0 <= i < m, 0 <= j < n, 0 <= l < k):
//   a(i, l)     = a[i + l * lda]
//   op(b)(l, j) is read from b[j + l * ldb]
//   c(i, j)     = c[i + j * ldc]
template <typename scalar_t, typename opmath_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
gemm_transb_(
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b.T), built up one l at a time:
  //   c(:, j) += a(:, l) * (op(b)(l, j) * alpha)
  constexpr int64_t unroll = 4;
  const int64_t m_main = (m / unroll) * unroll;
  const bool conj_b = (transb == TransposeType::ConjTranspose);

  for (const auto l : c10::irange(k)) {
    const int64_t a_off = l * lda;
    for (const auto j : c10::irange(n)) {
      const int64_t c_off = j * ldc;
      const int64_t b_idx = j + l * ldb;
      const opmath_t scaled_b =
          (conj_b ? conj_impl(b[b_idx]) : b[b_idx]) * alpha;

      // Manually unrolled by four over the contiguous index.
      int64_t i = 0;
      for (; i < m_main; i += unroll) {
        c[c_off + i + 0] += a[a_off + i + 0] * scaled_b;
        c[c_off + i + 1] += a[a_off + i + 1] * scaled_b;
        c[c_off + i + 2] += a[a_off + i + 2] * scaled_b;
        c[c_off + i + 3] += a[a_off + i + 3] * scaled_b;
      }
      // Remaining m % 4 elements.
      for (; i < m; ++i) {
        c[c_off + i] += a[a_off + i] * scaled_b;
      }
    }
  }
}
// Computes c = beta * c + alpha * (a @ op(b)), where op(b) is b transposed;
// the elements of b are additionally passed through conj_impl when
// transb == TransposeType::ConjTranspose. This overload is selected when
// scalar_t differs from opmath_t (e.g. at::BFloat16 or at::Half): each dot
// product is evaluated in opmath_t via sum() and converted to scalar_t only
// when it is stored into c.
//
// Element layout used below (0 <= i < m, 0 <= j < n, 0 <= l < k):
//   a(i, l)     = a[l * lda + i]
//   op(b)(l, j) is read from b[l * ldb + j]
//   c(i, j)     = c[j * ldc + i]
//
// When beta == 0 the old value of c(i, j) is not read.
template <typename scalar_t, typename opmath_t>
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
gemm_transb_(
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  const bool conj_b = (transb == TransposeType::ConjTranspose);
  const bool overwrite = (beta == opmath_t(0));

  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        const int64_t b_idx = l * ldb + j;
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(conj_b ? conj_impl(b[b_idx]) : b[b_idx]);
      });

      scalar_t& out = c[j * ldc + i];
      if (overwrite) {
        out = alpha * dot;
      } else {
        out = beta * out + alpha * dot;
      }
    }
  }
}
// Computes c = beta * c + alpha * (op(a) @ op(b)) with both operands
// transposed. The elements of a (resp. b) are additionally passed through
// conj_impl when transa (resp. transb) == TransposeType::ConjTranspose.
//
// Element layout used below (0 <= i < m, 0 <= j < n, 0 <= l < k):
//   op(a)(i, l) is read from a[i * lda + l]
//   op(b)(l, j) is read from b[l * ldb + j]
//   c(i, j)     = c[j * ldc + i]
//
// Each dot product is evaluated in opmath_t via sum(). When beta == 0 the old
// value of c(i, j) is not read.
template <typename scalar_t, typename opmath_t>
void gemm_transab_(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  const bool conj_a = (transa == TransposeType::ConjTranspose);
  const bool conj_b = (transb == TransposeType::ConjTranspose);
  const bool overwrite = (beta == opmath_t(0));

  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        const int64_t a_idx = i * lda + l;
        const int64_t b_idx = l * ldb + j;
        return static_cast<opmath_t>(conj_a ? conj_impl(a[a_idx]) : a[a_idx]) *
            static_cast<opmath_t>(conj_b ? conj_impl(b[b_idx]) : b[b_idx]);
      });

      scalar_t& out = c[j * ldc + i];
      if (overwrite) {
        out = alpha * dot;
      } else {
        out = beta * out + alpha * dot;
      }
    }
  }
}
// Selects the GEMM kernel matching the transpose flags of the two operands
// and forwards all arguments to it:
//
//   transa == NoTranspose, transb == NoTranspose  -> gemm_notrans_
//   transa != NoTranspose, transb == NoTranspose  -> gemm_transa_
//   transa == NoTranspose, transb != NoTranspose  -> gemm_transb_
//   transa != NoTranspose, transb != NoTranspose  -> gemm_transab_
//
// The transposed kernels also receive the flag(s) themselves so they can tell
// Transpose from ConjTranspose.
template <typename scalar_t, typename opmath_t>
void gemm_core_(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  const bool a_plain = (transa == TransposeType::NoTranspose);
  const bool b_plain = (transb == TransposeType::NoTranspose);

  if (a_plain) {
    if (b_plain) {
      gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    } else {
      gemm_transb_(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    }
  } else {
    if (b_plain) {
      gemm_transa_(transa, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    } else {
      gemm_transab_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    }
  }
}
// File-local dispatch macro listing the scalar types cpublas_gemm_impl
// handles. Non-mobile builds add the four Float8 variants on top of Half and
// BFloat16; C10_MOBILE builds only add Half and BFloat16.
//
// The name deliberately does not start with an underscore followed by an
// uppercase letter: such identifiers are reserved for the implementation.
#if !defined(C10_MOBILE)
#define CPUBLAS_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)                          \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                     \
      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
      TYPE, NAME, __VA_ARGS__)
#else
#define CPUBLAS_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)                          \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(                                     \
      kHalf, kBFloat16,                                                       \
      TYPE, NAME, __VA_ARGS__)
#endif

// Type-erased GEMM entry point: c = beta * c + alpha * (op(a) @ op(b)).
//
//   type            run-time element type of a, b and c
//   transa, transb  transpose flags forwarded to gemm_core_
//   m, n, k         problem sizes forwarded to gemm_core_
//   alpha, beta     scalars, converted to at::opmath_type<scalar_t>
//   a, b, c         untyped buffers, reinterpreted as scalar_t arrays
//   lda, ldb, ldc   leading dimensions forwarded to gemm_core_
//
// The dispatch macro resolves `type` to a concrete scalar_t and then calls
// gemm_core_ with the typed pointers.
void cpublas_gemm_impl(
    at::ScalarType type,
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const Scalar& alpha,
    const void *a, int64_t lda,
    const void *b, int64_t ldb,
    const Scalar& beta,
    void *c, int64_t ldc) {
  CPUBLAS_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
        using opmath_t = at::opmath_type<scalar_t>;
        gemm_core_(
            transa, transb, m, n, k,
            alpha.to<opmath_t>(),
            static_cast<const scalar_t *>(a), lda,
            static_cast<const scalar_t *>(b), ldb,
            beta.to<opmath_t>(),
            static_cast<scalar_t *>(c), ldc);
      });
}

// The macro has no other user in this file; do not let it leak further.
#undef CPUBLAS_DISPATCH_GEMM_TYPES
// Type-erased AXPY: y[i * incy] += a * x[i * incx] for 0 <= i < n.
//
//   type        run-time element type of x and y
//   n           number of elements processed
//   _a          scalar multiplier
//   _x, incx    input buffer and its element stride
//   _y, incy    in/out buffer and its element stride
//
// at::kBool is handled separately with logical operations
// (y |= a & x). All other supported types go through the dispatch macro, with
// the multiplier converted to at::opmath_type<scalar_t>.
void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
  if (type == at::kBool) {
    const auto a = _a.to<bool>();
    const auto* x = static_cast<const bool *>(_x);
    auto* y = static_cast<bool *>(_y);
    for (int64_t i = 0; i < n; ++i) {
      y[i*incy] |= a & x[i*incx];
    }
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16, type, "cpublas_axpy_impl",
    [&] {
      using opmath_t = at::opmath_type<scalar_t>;
      const auto a = _a.to<opmath_t>();
      const auto* x = static_cast<const scalar_t *>(_x);
      auto* y = static_cast<scalar_t *>(_y);
      for (int64_t i = 0; i < n; ++i) {
        y[i*incy] += a*x[i*incx];
      }
    });
}
// Type-erased strided copy: y[i * incy] = x[i * incx] for 0 <= i < n.
//
//   type        run-time element type of x and y
//   n           number of elements copied
//   _x, incx    source buffer and its element stride
//   _y, incy    destination buffer and its element stride
//
// The dispatch macro resolves `type` to a concrete scalar_t; besides the
// default type set it also covers ComplexHalf, Half, BFloat16 and Bool.
void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t incx, void *_y, int64_t incy){
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::kComplexHalf, at::kHalf, at::kBFloat16, at::kBool, type, "cpublas_copy_impl",
    [&] {
      const auto* src = static_cast<const scalar_t *>(_x);
      auto* dst = static_cast<scalar_t *>(_y);
      for (int64_t i = 0; i < n; ++i) {
        dst[i*incy] = src[i*incx];
      }
    });
}
}} // namespace cpublas::(anonymous)
// Associate the gemm/axpy/copy kernels defined above with the corresponding
// cpublas dispatch stubs.
// NOTE(review): REGISTER_DISPATCH and the *_stub objects are declared outside
// this file; how the registration is resolved at run time is not visible here.
REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl);
REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl);
REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl);
} // namespace at::native