Add BF16 in FP32 out fast path for EmbeddingBag in caffe2 perfkernels (#89198)
Add a BF16-input, FP32-output kernel to the Caffe2 embedding perfkernels, and update the Python code-gen script to generate it.
Unit tests will be covered in the next PR in this stack (#89199), which exercises nn.EmbeddingBag with the BF16 data type.
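For context, the new kernel widens each BF16 value to FP32 by zero-extending the 16-bit pattern into the upper half of a 32-bit word (`_mm256_cvtepu16_epi32` followed by `_mm256_slli_epi32(..., 16)` in the AVX2 path) and then accumulates with `_mm256_fmadd_ps`. Below is a minimal scalar sketch of that conversion; the helper name `bf16_to_fp32` is illustrative only and not part of this patch.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// BF16 is the upper 16 bits of an IEEE-754 float, so widening is just a
// zero-extend into the high half of a 32-bit word -- the scalar analogue of
// _mm256_slli_epi32(_mm256_cvtepu16_epi32(x), 16) used in the AVX2 kernel.
static float bf16_to_fp32(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

int main() {
  // 0x3F80 is the BF16 encoding of 1.0f.
  std::printf("%f\n", bf16_to_fp32(0x3F80));
  return 0;
}
```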
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89198
Approved by: https://github.com/jgong5, https://github.com/kit1980
diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc
index 2c9900b..48c869e 100644
--- a/caffe2/perfkernels/embedding_lookup_idx.cc
+++ b/caffe2/perfkernels/embedding_lookup_idx.cc
@@ -1,5 +1,6 @@
#include "caffe2/perfkernels/embedding_lookup_idx.h"
+#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/irange.h>
#include "caffe2/core/common.h"
@@ -214,6 +215,8 @@
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
+EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
+EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
@@ -221,6 +224,8 @@
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
+EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
+EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
diff --git a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc
index 674af83..3ed48a1 100644
--- a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc
+++ b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc
@@ -6,6 +6,7 @@
//// --------------------------
#include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>
#include <immintrin.h>
namespace caffe2 {
@@ -341,6 +342,7 @@
}
} else {
// generic code
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
float* op = &out[rangeIndex * block_size];
int64_t j = 0;
@@ -471,6 +473,7 @@
bool normalize_by_lengths,
float* out) {
const int64_t prefdist_T0 = 16;
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
const int64_t fused_block_size = block_size + 0;
int64_t dataInd = 0;
if (block_size == 128) {
@@ -511,7 +514,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const float* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -626,7 +631,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const float* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -701,7 +708,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const float* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -756,7 +765,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const float* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -780,6 +791,7 @@
}
} else {
// generic code
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
float* op = &out[rangeIndex * block_size];
int64_t j = 0;
@@ -807,7 +819,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const float* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1477,6 +1491,7 @@
bool normalize_by_lengths,
float* out) {
const int64_t prefdist_T0 = 16;
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
const int64_t fused_block_size = block_size + 0;
int64_t dataInd = 0;
if (block_size == 128) {
@@ -1517,7 +1532,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const at::Half* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1692,7 +1709,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const at::Half* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1797,7 +1816,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const at::Half* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1867,7 +1888,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const at::Half* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1928,7 +1951,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const at::Half* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -2022,6 +2047,1270 @@
}
template <bool IS_WEIGHT_POSITIONAL>
+static bool EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma(
+ const int64_t block_size,
+ const int64_t output_size,
+ const int64_t index_size,
+ const int64_t data_size,
+ const at::BFloat16* input,
+ const int* indices,
+ const int* offsets,
+ const float* weights,
+ const float* scale_bias,
+ bool normalize_by_lengths,
+ float* out) {
+ const int prefdist_T0 = 16;
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ const int fused_block_size = block_size + 0;
+ int64_t dataInd = 0;
+ if (block_size == 128) {
+ // unrolling 16 times
+ for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (0)))),
+ 16)),
+ vop0);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (8)))),
+ 16)),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (16)))),
+ 16)),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (24)))),
+ 16)),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (32)))),
+ 16)),
+ vop32);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (40)))),
+ 16)),
+ vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (48)))),
+ 16)),
+ vop48);
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (56)))),
+ 16)),
+ vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ vop64 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (64)))),
+ 16)),
+ vop64);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (72)))),
+ 16)),
+ vop72);
+ // skip unnecessary prefetch of (&ip_next_T0[72])
+ vop80 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (80)))),
+ 16)),
+ vop80);
+ // skip unnecessary prefetch of (&ip_next_T0[80])
+ vop88 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (88)))),
+ 16)),
+ vop88);
+ // skip unnecessary prefetch of (&ip_next_T0[88])
+ vop96 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (96)))),
+ 16)),
+ vop96);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (104)))),
+ 16)),
+ vop104);
+ // skip unnecessary prefetch of (&ip_next_T0[104])
+ vop112 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (112)))),
+ 16)),
+ vop112);
+ // skip unnecessary prefetch of (&ip_next_T0[112])
+ vop120 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (120)))),
+ 16)),
+ vop120);
+ // skip unnecessary prefetch of (&ip_next_T0[120])
+ }
+ if (!normalize_by_lengths || length == 0) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (0)))),
+ 16)),
+ vop0);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (8)))),
+ 16)),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (16)))),
+ 16)),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (24)))),
+ 16)),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (32)))),
+ 16)),
+ vop32);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (40)))),
+ 16)),
+ vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (48)))),
+ 16)),
+ vop48);
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (56)))),
+ 16)),
+ vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ }
+ if (!normalize_by_lengths || length == 0) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (0)))),
+ 16)),
+ vop0);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (8)))),
+ 16)),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (16)))),
+ 16)),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (24)))),
+ 16)),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ }
+ if (!normalize_by_lengths || length == 0) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else if (block_size == 16) {
+ // unrolling 2 times
+ for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (0)))),
+ 16)),
+ vop0);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (8)))),
+ 16)),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ }
+ if (!normalize_by_lengths || length == 0) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ } else {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
+ alignas(64) at::BFloat16 vtmp1[8] = {0};
+ for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ int64_t j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(&ip[j]))),
+ 16)),
+ _mm256_loadu_ps(&op[j])));
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ for (; j < block_size; j++) {
+ vtmp1[0] = ip[j];
+ __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
+ 16));
+ op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
+ }
+ }
+ if (normalize_by_lengths && length) {
+ float len_inv = 1.0f / length;
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+ return dataInd == index_size;
+}
+bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__avx2_fma(
+ const int64_t block_size,
+ const int64_t output_size,
+ const int64_t index_size,
+ const int64_t data_size,
+ const at::BFloat16* input,
+ const int* indices,
+ const int* offsets,
+ const float* weights,
+ const float* scale_bias,
+ bool normalize_by_lengths,
+ float* out) {
+ return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<false>(
+ block_size,
+ output_size,
+ index_size,
+ data_size,
+ input,
+ indices,
+ offsets,
+ weights,
+ scale_bias,
+ normalize_by_lengths,
+ out);
+}
+bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__avx2_fma(
+ const int64_t block_size,
+ const int64_t output_size,
+ const int64_t index_size,
+ const int64_t data_size,
+ const at::BFloat16* input,
+ const int* indices,
+ const int* offsets,
+ const float* weights,
+ const float* scale_bias,
+ bool normalize_by_lengths,
+ float* out) {
+ return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<true>(
+ block_size,
+ output_size,
+ index_size,
+ data_size,
+ input,
+ indices,
+ offsets,
+ weights,
+ scale_bias,
+ normalize_by_lengths,
+ out);
+}
+
+template <bool IS_WEIGHT_POSITIONAL>
+static bool EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma(
+ const int64_t block_size,
+ const int64_t output_size,
+ const int64_t index_size,
+ const int64_t data_size,
+ const at::BFloat16* input,
+ const int64_t* indices,
+ const int64_t* offsets,
+ const float* weights,
+ const float* scale_bias,
+ bool normalize_by_lengths,
+ float* out) {
+ const int64_t prefdist_T0 = 16;
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ const int64_t fused_block_size = block_size + 0;
+ int64_t dataInd = 0;
+ if (block_size == 128) {
+ // unrolling 16 times
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (0)))),
+ 16)),
+ vop0);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (8)))),
+ 16)),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (16)))),
+ 16)),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (24)))),
+ 16)),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (32)))),
+ 16)),
+ vop32);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (40)))),
+ 16)),
+ vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (48)))),
+ 16)),
+ vop48);
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (56)))),
+ 16)),
+ vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ vop64 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (64)))),
+ 16)),
+ vop64);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (72)))),
+ 16)),
+ vop72);
+ // skip unnecessary prefetch of (&ip_next_T0[72])
+ vop80 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (80)))),
+ 16)),
+ vop80);
+ // skip unnecessary prefetch of (&ip_next_T0[80])
+ vop88 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (88)))),
+ 16)),
+ vop88);
+ // skip unnecessary prefetch of (&ip_next_T0[88])
+ vop96 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (96)))),
+ 16)),
+ vop96);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (104)))),
+ 16)),
+ vop104);
+ // skip unnecessary prefetch of (&ip_next_T0[104])
+ vop112 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (112)))),
+ 16)),
+ vop112);
+ // skip unnecessary prefetch of (&ip_next_T0[112])
+ vop120 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (120)))),
+ 16)),
+ vop120);
+ // skip unnecessary prefetch of (&ip_next_T0[120])
+ }
+ if (!normalize_by_lengths || length == 0) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (0)))),
+ 16)),
+ vop0);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (8)))),
+ 16)),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (16)))),
+ 16)),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (24)))),
+ 16)),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (32)))),
+ 16)),
+ vop32);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (40)))),
+ 16)),
+ vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (48)))),
+ 16)),
+ vop48);
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (56)))),
+ 16)),
+ vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ }
+ if (!normalize_by_lengths || length == 0) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (0)))),
+ 16)),
+ vop0);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (8)))),
+ 16)),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (16)))),
+ 16)),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (24)))),
+ 16)),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ }
+ if (!normalize_by_lengths || length == 0) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else if (block_size == 16) {
+ // unrolling 2 times
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (0)))),
+ 16)),
+ vop0);
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(ip + (8)))),
+ 16)),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ }
+ if (!normalize_by_lengths || length == 0) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ } else {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
+ alignas(64) at::BFloat16 vtmp1[8] = {0};
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ int64_t j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ if (dataInd != offsets[rangeIndex] - offsets[0]) {
+ return false;
+ }
+ int64_t end_offset = offsets[rangeIndex + 1];
+ int64_t length = end_offset - offsets[rangeIndex];
+ for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ if (idx < 0 || idx >= data_size) {
+ return false;
+ }
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const at::BFloat16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ ? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+ return false;
+ }
+ const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt,
+ _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(&ip[j]))),
+ 16)),
+ _mm256_loadu_ps(&op[j])));
+ _mm_prefetch(
+ reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ for (; j < block_size; j++) {
+ vtmp1[0] = ip[j];
+ __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
+ 16));
+ op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
+ }
+ }
+ if (normalize_by_lengths && length) {
+ float len_inv = 1.0f / length;
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+ return dataInd == index_size;
+}
+bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__avx2_fma(
+ const int64_t block_size,
+ const int64_t output_size,
+ const int64_t index_size,
+ const int64_t data_size,
+ const at::BFloat16* input,
+ const int64_t* indices,
+ const int64_t* offsets,
+ const float* weights,
+ const float* scale_bias,
+ bool normalize_by_lengths,
+ float* out) {
+ return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<false>(
+ block_size,
+ output_size,
+ index_size,
+ data_size,
+ input,
+ indices,
+ offsets,
+ weights,
+ scale_bias,
+ normalize_by_lengths,
+ out);
+}
+bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__avx2_fma(
+ const int64_t block_size,
+ const int64_t output_size,
+ const int64_t index_size,
+ const int64_t data_size,
+ const at::BFloat16* input,
+ const int64_t* indices,
+ const int64_t* offsets,
+ const float* weights,
+ const float* scale_bias,
+ bool normalize_by_lengths,
+ float* out) {
+ return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<true>(
+ block_size,
+ output_size,
+ index_size,
+ data_size,
+ input,
+ indices,
+ offsets,
+ weights,
+ scale_bias,
+ normalize_by_lengths,
+ out);
+}
+
+template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
const int64_t block_size,
const int64_t output_size,
@@ -2483,6 +3772,7 @@
}
} else {
// generic code
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
float* op = &out[rangeIndex * block_size];
int64_t j = 0;
@@ -2621,6 +3911,7 @@
bool normalize_by_lengths,
float* out) {
const int64_t prefdist_T0 = 16;
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
const int64_t fused_block_size = block_size + 0;
int64_t dataInd = 0;
if (block_size == 128) {
@@ -2666,7 +3957,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const uint8_t* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -2844,7 +4137,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const uint8_t* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -2953,7 +4248,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const uint8_t* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -3028,7 +4325,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const uint8_t* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -3060,6 +4359,7 @@
}
} else {
// generic code
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
float* op = &out[rangeIndex * block_size];
int64_t j = 0;
@@ -3092,7 +4392,9 @@
__m256 vwgt = _mm256_set1_ps(wgt);
const uint8_t* ip = &input[idx * fused_block_size];
const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? (dataInd + prefdist_T0)
+ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
: dataInd;
const int64_t idx_pref_T0 = indices[next_T0];
if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py
index 402f3bb..7e4208c 100644
--- a/caffe2/perfkernels/hp_emblookup_codegen.py
+++ b/caffe2/perfkernels/hp_emblookup_codegen.py
@@ -4,7 +4,7 @@
import sys
-sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1}
+sizeof = {"float": 4, "at::Half": 2, "at::BFloat16": 2, "uint8_t": 1}
def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
@@ -24,6 +24,16 @@
" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n" # noqa
" vop%d);" % (regid, regid, regid)
)
+ elif InType == "at::BFloat16":
+ code.append(
+ " vop%d = _mm256_fmadd_ps(\n"
+ " vwgt,\n"
+ " _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+ " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
+ " reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
+ " 16)),\n" # noqa
+ " vop%d);" % (regid, regid, regid)
+ )
elif InType == "uint8_t":
code.append(
" vop%d = _mm256_fmadd_ps(\n"
@@ -104,6 +114,7 @@
if InType == "uint8_t":
code.append(" " + OutType + " wgt = 1.f;")
+ code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
code.append(" " + OutType + " bio;")
code.append(" if (weights) {")
code.append(
@@ -133,7 +144,10 @@
code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
code.append(
" const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
- " ? (dataInd + prefdist_T0)\n : dataInd;".format(
+ " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+ " ? (dataInd + prefdist_T0)\n"
+ " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+ " : dataInd;".format(
IndexType
)
)
@@ -206,6 +220,18 @@
" reinterpret_cast<const __m128i*>(&ip[j]))),\n"
" _mm256_loadu_ps(&op[j])));"
)
+ elif InType == "at::BFloat16":
+ code.append(
+ " _mm256_storeu_ps(\n"
+ " &op[j],\n"
+ " _mm256_fmadd_ps(\n"
+ " vwgt,\n"
+ " _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+ " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
+ " reinterpret_cast<const __m128i*>(&ip[j]))),\n"
+ " 16)),\n"
+ " _mm256_loadu_ps(&op[j])));"
+ )
elif InType == "uint8_t":
code.append(
" _mm256_storeu_ps(\n"
@@ -229,7 +255,8 @@
code = []
if InType == "at::Half":
code.append(" alignas(64) at::Half vtmp1[8] = {0};")
-
+ if InType == "at::BFloat16":
+ code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};")
if use_offsets:
@@ -291,6 +318,7 @@
if InType == "uint8_t":
code.append(" " + OutType + " wgt = 1.f;")
+ code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
code.append(" " + OutType + " bio;")
code.append(" if (weights) {")
code.append(
@@ -320,7 +348,10 @@
code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
code.append(
" const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
- " ? (dataInd + prefdist_T0)\n : dataInd;".format(
+ " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+ " ? (dataInd + prefdist_T0)\n"
+ " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+ " : dataInd;".format(
IndexType
)
)
@@ -351,6 +382,14 @@
" _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
)
code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
+ elif InType == "at::BFloat16":
+ code.append(" vtmp1[0] = ip[j];")
+ code.append(
+ " __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+ " _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),\n"
+ " 16));"
+ )
+ code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
elif InType == "uint8_t":
code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
else:
@@ -408,6 +447,8 @@
["int64_t", "int64_t", "float", "float", "float", "float"],
["int32_t", "int", "half", "at::Half", "float", "float"],
["int64_t", "int64_t", "half", "at::Half", "float", "float"],
+ ["int32_t", "int", "bfloat16", "at::BFloat16", "float", "float"],
+ ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
]
@@ -422,6 +463,7 @@
code.append("//// --------------------------\n")
code.append("#include <c10/util/Half.h>")
+code.append("#include <c10/util/BFloat16.h>")
code.append("#include <immintrin.h>")
code.append("namespace caffe2 {\n")
@@ -461,6 +503,7 @@
code += args
code.append(" const " + IndexType + " prefdist_T0 = 16;")
+ code.append(" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)")
# block_size is the number of elements and fused_block_size is the size of
# an entire row, including scale and bias.
offset = (8 // sizeof[InType]) if opts.fused else 0
@@ -484,6 +527,7 @@
code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
code.append(" } else {")
code.append(" // generic code")
+ code.append(" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)")
code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
code.append(" }")
code.append(" return dataInd == index_size;")