// blob: b00e2650243288d5c07770bd74c686abbf344430
#include "caffe2/contrib/fakelowp/fp16_gemm_utils.h"
#include <fbgemm/FbgemmConvert.h>
#include <fbgemm/FbgemmFP16.h>
#include <glog/logging.h>
#include <immintrin.h>
#include "caffe2/core/context.h"
#include "caffe2/utils/math.h"
C10_DECLARE_bool(caffe2_fbgemm_fake_fp16_clamp);
namespace caffe2 {
// dimA(before transpose) = M x N, dimA (after transpose) = N x M.
// Transposes the M x N row-major matrix A into A_trans (N x M, row-major).
// dimA(before transpose) = M x N, dimA (after transpose) = N x M.
void transpose(const float* A, std::vector<float>& A_trans, int M, int N) {
  // The destination buffer must already hold exactly M * N elements.
  CAFFE_ENFORCE_EQ(M * N, A_trans.size());
  float* const dst = A_trans.data();
  // fbgemm SIMD transpose: source leading dimension is N, destination's is M.
  fbgemm::transpose_simd(M, N, A, N, dst, M);
}
// Computes C = op(A) * op(B) + beta * C via custom_fp16_gemm, where
// op(A) is m x k, op(B) is k x n and C is m x n (all row-major).
// custom_fp16_gemm only handles the NoTrans/NoTrans layout, so transposed
// operands are materialized into explicit temporary copies first.
void custom_fp16_gemm_with_trans(
    const CBLAS_TRANSPOSE trans_A,
    const CBLAS_TRANSPOSE trans_B,
    const int m,
    const int k,
    const int n,
    const float* A,
    const float* B,
    const float beta,
    float* C,
    const bool use_acc_fp16,
    const bool use_temp_accumulator) {
  switch (trans_A) {
    case CblasNoTrans: {
      switch (trans_B) {
        case CblasNoTrans: {
          // A * B
          custom_fp16_gemm(
              m, k, n, A, B, beta, C, use_acc_fp16, use_temp_accumulator);
          break;
        }
        case CblasTrans: {
          // A * B_trans; B is stored n x k, transpose it to k x n.
          std::vector<float> B_trans(n * k);
          transpose(B, B_trans, n, k);
          custom_fp16_gemm(
              m,
              k,
              n,
              A,
              B_trans.data(),
              beta,
              C,
              use_acc_fp16,
              use_temp_accumulator);
          break;
        }
        default:
          LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B";
      }
    } break;
    case CblasTrans: {
      switch (trans_B) {
        case CblasNoTrans: {
          // A_trans * B; A is stored k x m, transpose it to m x k.
          std::vector<float> A_trans(k * m);
          transpose(A, A_trans, k, m);
          custom_fp16_gemm(
              m,
              k,
              n,
              A_trans.data(),
              B,
              beta,
              C,
              use_acc_fp16,
              use_temp_accumulator);
          break;
        }
        case CblasTrans: {
          // A_trans * B_trans; transpose both operands.
          std::vector<float> A_trans(k * m);
          std::vector<float> B_trans(n * k);
          transpose(A, A_trans, k, m);
          transpose(B, B_trans, n, k);
          custom_fp16_gemm(
              m,
              k,
              n,
              A_trans.data(),
              B_trans.data(),
              beta,
              C,
              use_acc_fp16,
              use_temp_accumulator);
          break;
        }
        default:
          LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B";
      }
    } break;
    default:
      LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_A";
  }
}
// Flushes fp16 subnormals to zero: lanes with |x| < epsilon_ become 0.0f,
// all other lanes pass through unchanged.
static inline __m256 clamp_subnormals(__m256 input, const float epsilon_) {
  const __m256 pos_bound = _mm256_set1_ps(epsilon_);
  const __m256 neg_bound = _mm256_set1_ps(-epsilon_);
  // Keep-mask is all-ones where input >= epsilon_ or input <= -epsilon_.
  const __m256 keep = _mm256_or_ps(
      _mm256_cmp_ps(input, pos_bound, _CMP_GE_OS),
      _mm256_cmp_ps(input, neg_bound, _CMP_LE_OS));
  // AND with the mask zeroes the subnormal lanes.
  return _mm256_and_ps(input, keep);
}
// Emulated low-precision GEMM: C = A * B + beta * C, with A (m x k),
// B (k x n) and C (m x n) row-major. A_fp16/B_fp16 hold fp32 values assumed
// to already be fp16-representable (rounded by the caller).
//
// If use_acc_fp16 is true, every product and partial sum is rounded back to
// fp16 (via cvtps_ph/cvtph_ps) and subnormals are flushed to zero, mimicking
// a true fp16 accumulator. Otherwise accumulation is fp32 and C is rounded
// to fp16 once at the end.
//
// If use_temp_accumulator is true, each (row, k-block, j-tile) accumulates
// into a fresh temporary register before being merged into C, which changes
// the summation order to match the emulated hardware.
void custom_fp16_gemm(
    const int m,
    const int k,
    const int n,
    const float* A_fp16,
    const float* B_fp16,
    const float beta,
    float* C,
    const bool use_acc_fp16,
    const bool use_temp_accumulator) {
#ifdef LOG_LEVEL_FOR_FBFCPACKEDACC16_PERFORMANCE_LOG
  clock_t begin = clock();
#endif
  int C_size = m * n;
  if (beta == 0) {
    // In Caffe2 we often do a lazy initialization, which may contain NaNs in
    // the float values. As a result, if beta is 0, we explicitly do a setzero.
    memset(C, 0, C_size * sizeof(C[0]));
  } else {
    // Round beta itself to fp16, then scale C by it: 8 floats at a time
    // vectorized, scalar tail afterwards.
    float beta_fp16 = fbgemm::cpu_half2float(fbgemm::cpu_float2half_rn(beta));
    __m256 mBetaFp16 = _mm256_broadcast_ss(&beta_fp16);
    int i = 0;
    for (i = 0; i + 8 <= C_size; i += 8) {
      __m256 mC = _mm256_loadu_ps(C + i);
      mC = _mm256_mul_ps(mC, mBetaFp16);
      _mm256_storeu_ps(C + i, mC);
    }
    for (; i < C_size; i++) {
      C[i] = C[i] * beta_fp16;
    }
  }
  // Encode the smallest normal number in float16
  union epsilon_t {
    float f;
    uint32_t i;
  };
  union epsilon_t epsilon;
  epsilon.i = 0x38800000u; // 1 / 16384
  constexpr int kVLen = 8; // floats per AVX register
  const int kb_max = 128; // blocking factor along the k dimension
  for (int i = 0; i < m; i++) {
    for (int l = 0; l < k; l += kb_max) {
      int kb = std::min(kb_max, k - l);
      for (int j = 0; j < n; j += kVLen) {
        int nb = std::min(kVLen, n - j);
        if (nb == kVLen) {
          // Full 8-wide tile: C[i, j:j+8] += A[i, l:l+kb] * B[l:l+kb, j:j+8].
          __m256 mC = _mm256_loadu_ps(C + i * n + j);
          __m256 mC_temp = _mm256_setzero_ps();
          for (int l2 = l; l2 < l + kb; l2++) {
            __m256 mA_fp16 = _mm256_broadcast_ss(A_fp16 + i * k + l2);
            __m256 mB_fp16 = _mm256_loadu_ps((B_fp16 + l2 * n + j));
            if (use_acc_fp16) {
              mA_fp16 = clamp_subnormals(mA_fp16, epsilon.f);
              mB_fp16 = clamp_subnormals(mB_fp16, epsilon.f);
            }
            __m256 mAB = _mm256_mul_ps(mA_fp16, mB_fp16);
            if (use_acc_fp16) {
              // Round the product to fp16 before accumulating, and round
              // the running sum after each add.
              __m256 mAB_fp16 = _mm256_cvtph_ps(_mm256_cvtps_ph(mAB, 0));
              mAB_fp16 = clamp_subnormals(mAB_fp16, epsilon.f);
              if (use_temp_accumulator) {
                mC_temp = _mm256_add_ps(mC_temp, mAB_fp16);
                mC_temp = _mm256_cvtph_ps(_mm256_cvtps_ph(mC_temp, 0));
              } else {
                mC = _mm256_add_ps(mC, mAB_fp16);
                mC = _mm256_cvtph_ps(_mm256_cvtps_ph(mC, 0));
              }
            } else {
              if (use_temp_accumulator) {
                mC_temp = _mm256_add_ps(mC_temp, mAB);
              } else {
                mC = _mm256_add_ps(mC, mAB);
              }
            }
            if (use_acc_fp16) {
              mC = clamp_subnormals(mC, epsilon.f);
            }
          }
          if (use_temp_accumulator) {
            // Merge the block-local accumulator into C (with fp16 rounding
            // around the merge when emulating an fp16 accumulator).
            if (use_acc_fp16) {
              mC = _mm256_cvtph_ps(_mm256_cvtps_ph(mC, 0));
              mC = _mm256_add_ps(mC, mC_temp);
              mC = _mm256_cvtph_ps(_mm256_cvtps_ph(mC, 0));
            } else {
              mC = _mm256_add_ps(mC, mC_temp);
            }
          }
          _mm256_storeu_ps(C + i * n + j, mC);
        } else {
          // Ragged tail tile (nb < 8): same computation but with masked
          // loads/stores so we never touch memory past column n.
          __m256 mC_temp = _mm256_setzero_ps();
          // Sliding window over this table yields nb leading -1 (keep)
          // lanes followed by 0 (skip) lanes.
          static constexpr int32_t mask_src[16] = {
              -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0};
          __m256i imask =
              _mm256_loadu_si256((__m256i const*)(mask_src + 8 - nb));
          __m256 mC = _mm256_maskload_ps(C + i * n + j, imask);
          for (int l2 = l; l2 < l + kb; l2++) {
            __m256 mA_fp16 = _mm256_broadcast_ss(A_fp16 + i * k + l2);
            __m256 mB_fp16 = _mm256_maskload_ps(B_fp16 + l2 * n + j, imask);
            if (use_acc_fp16) {
              mA_fp16 = clamp_subnormals(mA_fp16, epsilon.f);
              mB_fp16 = clamp_subnormals(mB_fp16, epsilon.f);
            }
            __m256 mAB = _mm256_mul_ps(mA_fp16, mB_fp16);
            if (use_acc_fp16) {
              __m256 mAB_fp16 = _mm256_cvtph_ps(_mm256_cvtps_ph(mAB, 0));
              mAB_fp16 = clamp_subnormals(mAB_fp16, epsilon.f);
              if (use_temp_accumulator) {
                mC_temp = _mm256_add_ps(mC_temp, mAB_fp16);
                mC_temp = _mm256_cvtph_ps(_mm256_cvtps_ph(mC_temp, 0));
              } else {
                mC = _mm256_add_ps(mC, mAB_fp16);
                mC = _mm256_cvtph_ps(_mm256_cvtps_ph(mC, 0));
              }
            } else {
              if (use_temp_accumulator) {
                mC_temp = _mm256_add_ps(mC_temp, mAB);
              } else {
                mC = _mm256_add_ps(mC, mAB);
              }
            }
            if (use_acc_fp16) {
              mC = clamp_subnormals(mC, epsilon.f);
            }
          }
          if (use_temp_accumulator) {
            if (use_acc_fp16) {
              mC = _mm256_cvtph_ps(_mm256_cvtps_ph(mC, 0));
              mC = _mm256_add_ps(mC, mC_temp);
              mC = _mm256_cvtph_ps(_mm256_cvtps_ph(mC, 0));
            } else {
              mC = _mm256_add_ps(mC, mC_temp);
            }
          }
          _mm256_maskstore_ps(C + i * n + j, imask, mC);
        }
      }
    }
  }
  if (!use_acc_fp16) {
    // With fp32 accumulation the result still has to be fp16-representable:
    // round every element of C once at the end.
    constexpr int kSize = 8;
    int i = 0;
    for (; i + kSize <= C_size; i += kSize) {
      __m256 mC = _mm256_loadu_ps(C + i);
      mC = _mm256_cvtph_ps(_mm256_cvtps_ph(mC, 0));
      _mm256_storeu_ps(C + i, mC);
    }
    if (i < C_size) {
      // Round the ragged tail through a small scratch buffer so we never
      // read or write past the end of C.
      std::vector<float> tmp(8);
      for (int kk = 0; kk + i < C_size; kk++) {
        tmp[kk] = C[i + kk];
      }
      __m256 mC = _mm256_loadu_ps(tmp.data());
      mC = _mm256_cvtph_ps(_mm256_cvtps_ph(mC, 0));
      _mm256_storeu_ps(tmp.data(), mC);
      for (int kk = 0; kk + i < C_size; kk++) {
        C[i + kk] = tmp[kk];
      }
    }
  }
#ifdef LOG_LEVEL_FOR_FBFCPACKEDACC16_PERFORMANCE_LOG
  clock_t end = clock();
  double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC;
  VLOG(LOG_LEVEL_FOR_FBFCPACKEDACC16_ACCURACY_LOG)
      << "cblas_gemm_compute_acc16 run time = " << elapsed_secs << std::endl;
#endif
}
// GEMV dispatch: uses the custom emulated GEMM when fp16 accumulation is
// requested, or when custom fp32 accumulation with a temporary accumulator
// is enabled; otherwise falls back to the stock Caffe2 Gemv.
// NOTE: alpha is only honored on the fallback path.
void custom_fp16_gemv(
    const bool use_acc_fp16,
    const bool use_custom_acc32,
    const bool use_temp_accumulator,
    const CBLAS_TRANSPOSE trans_A,
    const int M,
    const int N,
    const float alpha,
    const float* A,
    const float* x,
    const float beta,
    float* y,
    CPUContext* context) {
  const bool use_custom_path =
      use_acc_fp16 || (use_custom_acc32 && use_temp_accumulator);
  if (use_custom_path) {
    // A GEMV is expressed as a GEMM with an inner dimension of 1; the
    // accumulator precision follows use_acc_fp16.
    custom_fp16_gemm_with_trans(
        trans_A,
        CblasNoTrans,
        M,
        1,
        N,
        A,
        x,
        beta,
        y,
        use_acc_fp16,
        use_temp_accumulator);
  } else {
    math::Gemv<float, CPUContext>(trans_A, M, N, alpha, A, x, beta, y, context);
  }
}
// Batched GEMM dispatch over arrays of matrix pointers. Falls back to the
// stock Caffe2 GemmBatched unless a custom accumulation mode is requested;
// otherwise runs the emulated GEMM once per batch element.
// NOTE: alpha is only honored on the fallback path.
void custom_fp16_gemm_batched(
    const bool use_acc_fp16,
    const bool use_custom_acc32,
    const bool use_temp_accumulator,
    const CBLAS_TRANSPOSE trans_A,
    const CBLAS_TRANSPOSE trans_B,
    const int batch_size,
    const int M,
    const int N,
    const int K,
    const float alpha,
    const float** A,
    const float** B,
    const float beta,
    float** C,
    CPUContext* context) {
  const bool use_custom_path =
      use_acc_fp16 || (use_custom_acc32 && use_temp_accumulator);
  if (!use_custom_path) {
    math::GemmBatched<float, CPUContext>(
        trans_A, trans_B, batch_size, M, N, K, alpha, A, B, beta, C, context);
    return;
  }
  for (int batch = 0; batch < batch_size; ++batch) {
    if (!use_acc_fp16) {
      // Sanity check: reaching here without fp16 accumulation implies the
      // custom fp32 path with a temporary accumulator.
      CAFFE_ENFORCE(use_custom_acc32 && use_temp_accumulator);
    }
    custom_fp16_gemm_with_trans(
        trans_A,
        trans_B,
        M,
        K,
        N,
        A[batch],
        B[batch],
        beta,
        C[batch],
        use_acc_fp16,
        use_temp_accumulator);
  }
}
// Strided batched GEMM: runs the emulated GEMM once per batch element,
// advancing A/B/C by their respective strides between iterations. The
// accumulator precision follows use_acc_fp16; alpha is ignored.
void custom_fp16_gemm_strided_batched(
    const bool use_acc_fp16,
    const bool use_custom_acc32,
    const bool use_temp_accumulator,
    const CBLAS_TRANSPOSE trans_A,
    const CBLAS_TRANSPOSE trans_B,
    const int batch_size,
    const int M,
    const int N,
    const int K,
    const float alpha /* unused */,
    const float* A,
    const int A_stride,
    const float* B,
    const int B_stride,
    const float beta,
    float* C,
    const int C_stride,
    CPUContext* context) {
  const float* A_cur = A;
  const float* B_cur = B;
  float* C_cur = C;
  // loop over matrices in the batch
  for (int batch = 0; batch < batch_size; ++batch) {
    custom_fp16_gemm_with_trans(
        trans_A,
        trans_B,
        M,
        K,
        N,
        A_cur,
        B_cur,
        beta,
        C_cur,
        use_acc_fp16,
        use_temp_accumulator);
    A_cur += A_stride;
    B_cur += B_stride;
    C_cur += C_stride;
  }
}
} // namespace caffe2