Add BF16-in/FP32-out fast path for EmbeddingBag in Caffe2 perfkernels (#89198)

Add a BF16-input/FP32-output kernel to the Caffe2 embedding perfkernels, and update the Python code-gen scripts to generate the new kernel.
Unit tests are covered by the next PR in this stack (#89199), which tests nn.EmbeddingBag with the BF16 data type.
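For reference, the core of the new kernels is a BF16-to-FP32 widening load: a bfloat16 value is the upper 16 bits of an IEEE-754 float, so each lane is zero-extended to 32 bits and shifted left by 16 before the FP32 FMA accumulation. A minimal, standalone sketch of that conversion (using plain `uint16_t` for the bf16 bit pattern instead of `at::BFloat16`, and a hypothetical `bf16_ptr` argument) looks like this:

```cpp
#include <immintrin.h>
#include <cstdint>
#include <cstring>

// Scalar reference: placing the 16-bit bf16 pattern into the high half of a
// 32-bit word reproduces the original float exactly.
static inline float bf16_to_fp32(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float result;
  std::memcpy(&result, &widened, sizeof(result));
  return result;
}

// Vector form mirroring the generated kernel: load 8 bf16 values, zero-extend
// each lane to 32 bits, shift into the high half, and reinterpret as 8 floats.
static inline __m256 load8_bf16_as_fp32(const uint16_t* bf16_ptr) {
  __m128i raw = _mm_loadu_si128(reinterpret_cast<const __m128i*>(bf16_ptr));
  __m256i widened = _mm256_cvtepu16_epi32(raw);     // u16 -> u32 per lane
  __m256i shifted = _mm256_slli_epi32(widened, 16); // bits into float position
  return _mm256_castsi256_ps(shifted);
}
```

The generated kernels below then accumulate these FP32 lanes with `_mm256_fmadd_ps` against the per-row weight, which is why the output stays in FP32.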

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89198
Approved by: https://github.com/jgong5, https://github.com/kit1980
diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc
index 2c9900b..48c869e 100644
--- a/caffe2/perfkernels/embedding_lookup_idx.cc
+++ b/caffe2/perfkernels/embedding_lookup_idx.cc
@@ -1,5 +1,6 @@
 #include "caffe2/perfkernels/embedding_lookup_idx.h"
 
+#include <c10/util/BFloat16.h>
 #include <c10/util/Half.h>
 #include <c10/util/irange.h>
 #include "caffe2/core/common.h"
@@ -214,6 +215,8 @@
 EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
+EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
+EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
 
@@ -221,6 +224,8 @@
 EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
+EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
+EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
 
diff --git a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc
index 674af83..3ed48a1 100644
--- a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc
+++ b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc
@@ -6,6 +6,7 @@
 //// --------------------------
 
 #include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>
 #include <immintrin.h>
 namespace caffe2 {
 
@@ -341,6 +342,7 @@
     }
   } else {
     // generic code
+    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
       float* op = &out[rangeIndex * block_size];
       int64_t j = 0;
@@ -471,6 +473,7 @@
     bool normalize_by_lengths,
     float* out) {
   const int64_t prefdist_T0 = 16;
+  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
   const int64_t fused_block_size = block_size + 0;
   int64_t dataInd = 0;
   if (block_size == 128) {
@@ -511,7 +514,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const float* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -626,7 +631,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const float* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -701,7 +708,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const float* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -756,7 +765,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const float* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -780,6 +791,7 @@
     }
   } else {
     // generic code
+    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
       float* op = &out[rangeIndex * block_size];
       int64_t j = 0;
@@ -807,7 +819,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const float* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1477,6 +1491,7 @@
     bool normalize_by_lengths,
     float* out) {
   const int64_t prefdist_T0 = 16;
+  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
   const int64_t fused_block_size = block_size + 0;
   int64_t dataInd = 0;
   if (block_size == 128) {
@@ -1517,7 +1532,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const at::Half* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1692,7 +1709,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const at::Half* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1797,7 +1816,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const at::Half* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1867,7 +1888,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const at::Half* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -1928,7 +1951,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const at::Half* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -2022,6 +2047,1270 @@
 }
 
 template <bool IS_WEIGHT_POSITIONAL>
+static bool EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma(
+    const int64_t block_size,
+    const int64_t output_size,
+    const int64_t index_size,
+    const int64_t data_size,
+    const at::BFloat16* input,
+    const int* indices,
+    const int* offsets,
+    const float* weights,
+    const float* scale_bias,
+    bool normalize_by_lengths,
+    float* out) {
+  const int prefdist_T0 = 16;
+  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+  const int fused_block_size = block_size + 0;
+  int64_t dataInd = 0;
+  if (block_size == 128) {
+    // unrolling 16 times
+    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      __m256 vop0 = _mm256_setzero_ps();
+      __m256 vop8 = _mm256_setzero_ps();
+      __m256 vop16 = _mm256_setzero_ps();
+      __m256 vop24 = _mm256_setzero_ps();
+      __m256 vop32 = _mm256_setzero_ps();
+      __m256 vop40 = _mm256_setzero_ps();
+      __m256 vop48 = _mm256_setzero_ps();
+      __m256 vop56 = _mm256_setzero_ps();
+      __m256 vop64 = _mm256_setzero_ps();
+      __m256 vop72 = _mm256_setzero_ps();
+      __m256 vop80 = _mm256_setzero_ps();
+      __m256 vop88 = _mm256_setzero_ps();
+      __m256 vop96 = _mm256_setzero_ps();
+      __m256 vop104 = _mm256_setzero_ps();
+      __m256 vop112 = _mm256_setzero_ps();
+      __m256 vop120 = _mm256_setzero_ps();
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        vop0 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (0)))),
+                16)),
+            vop0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+        vop8 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (8)))),
+                16)),
+            vop8);
+        // skip unnecessary prefetch of (&ip_next_T0[8])
+        vop16 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (16)))),
+                16)),
+            vop16);
+        // skip unnecessary prefetch of (&ip_next_T0[16])
+        vop24 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (24)))),
+                16)),
+            vop24);
+        // skip unnecessary prefetch of (&ip_next_T0[24])
+        vop32 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (32)))),
+                16)),
+            vop32);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
+        vop40 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (40)))),
+                16)),
+            vop40);
+        // skip unnecessary prefetch of (&ip_next_T0[40])
+        vop48 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (48)))),
+                16)),
+            vop48);
+        // skip unnecessary prefetch of (&ip_next_T0[48])
+        vop56 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (56)))),
+                16)),
+            vop56);
+        // skip unnecessary prefetch of (&ip_next_T0[56])
+        vop64 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (64)))),
+                16)),
+            vop64);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
+        vop72 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (72)))),
+                16)),
+            vop72);
+        // skip unnecessary prefetch of (&ip_next_T0[72])
+        vop80 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (80)))),
+                16)),
+            vop80);
+        // skip unnecessary prefetch of (&ip_next_T0[80])
+        vop88 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (88)))),
+                16)),
+            vop88);
+        // skip unnecessary prefetch of (&ip_next_T0[88])
+        vop96 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (96)))),
+                16)),
+            vop96);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
+        vop104 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (104)))),
+                16)),
+            vop104);
+        // skip unnecessary prefetch of (&ip_next_T0[104])
+        vop112 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (112)))),
+                16)),
+            vop112);
+        // skip unnecessary prefetch of (&ip_next_T0[112])
+        vop120 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (120)))),
+                16)),
+            vop120);
+        // skip unnecessary prefetch of (&ip_next_T0[120])
+      }
+      if (!normalize_by_lengths || length == 0) {
+        _mm256_storeu_ps(&op[0], vop0);
+        _mm256_storeu_ps(&op[8], vop8);
+        _mm256_storeu_ps(&op[16], vop16);
+        _mm256_storeu_ps(&op[24], vop24);
+        _mm256_storeu_ps(&op[32], vop32);
+        _mm256_storeu_ps(&op[40], vop40);
+        _mm256_storeu_ps(&op[48], vop48);
+        _mm256_storeu_ps(&op[56], vop56);
+        _mm256_storeu_ps(&op[64], vop64);
+        _mm256_storeu_ps(&op[72], vop72);
+        _mm256_storeu_ps(&op[80], vop80);
+        _mm256_storeu_ps(&op[88], vop88);
+        _mm256_storeu_ps(&op[96], vop96);
+        _mm256_storeu_ps(&op[104], vop104);
+        _mm256_storeu_ps(&op[112], vop112);
+        _mm256_storeu_ps(&op[120], vop120);
+      } else {
+        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+      }
+    }
+  } else if (block_size == 64) {
+    // unrolling 8 times
+    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      __m256 vop0 = _mm256_setzero_ps();
+      __m256 vop8 = _mm256_setzero_ps();
+      __m256 vop16 = _mm256_setzero_ps();
+      __m256 vop24 = _mm256_setzero_ps();
+      __m256 vop32 = _mm256_setzero_ps();
+      __m256 vop40 = _mm256_setzero_ps();
+      __m256 vop48 = _mm256_setzero_ps();
+      __m256 vop56 = _mm256_setzero_ps();
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        vop0 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (0)))),
+                16)),
+            vop0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+        vop8 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (8)))),
+                16)),
+            vop8);
+        // skip unnecessary prefetch of (&ip_next_T0[8])
+        vop16 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (16)))),
+                16)),
+            vop16);
+        // skip unnecessary prefetch of (&ip_next_T0[16])
+        vop24 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (24)))),
+                16)),
+            vop24);
+        // skip unnecessary prefetch of (&ip_next_T0[24])
+        vop32 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (32)))),
+                16)),
+            vop32);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
+        vop40 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (40)))),
+                16)),
+            vop40);
+        // skip unnecessary prefetch of (&ip_next_T0[40])
+        vop48 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (48)))),
+                16)),
+            vop48);
+        // skip unnecessary prefetch of (&ip_next_T0[48])
+        vop56 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (56)))),
+                16)),
+            vop56);
+        // skip unnecessary prefetch of (&ip_next_T0[56])
+      }
+      if (!normalize_by_lengths || length == 0) {
+        _mm256_storeu_ps(&op[0], vop0);
+        _mm256_storeu_ps(&op[8], vop8);
+        _mm256_storeu_ps(&op[16], vop16);
+        _mm256_storeu_ps(&op[24], vop24);
+        _mm256_storeu_ps(&op[32], vop32);
+        _mm256_storeu_ps(&op[40], vop40);
+        _mm256_storeu_ps(&op[48], vop48);
+        _mm256_storeu_ps(&op[56], vop56);
+      } else {
+        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+      }
+    }
+  } else if (block_size == 32) {
+    // unrolling 4 times
+    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      __m256 vop0 = _mm256_setzero_ps();
+      __m256 vop8 = _mm256_setzero_ps();
+      __m256 vop16 = _mm256_setzero_ps();
+      __m256 vop24 = _mm256_setzero_ps();
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        vop0 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (0)))),
+                16)),
+            vop0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+        vop8 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (8)))),
+                16)),
+            vop8);
+        // skip unnecessary prefetch of (&ip_next_T0[8])
+        vop16 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (16)))),
+                16)),
+            vop16);
+        // skip unnecessary prefetch of (&ip_next_T0[16])
+        vop24 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (24)))),
+                16)),
+            vop24);
+        // skip unnecessary prefetch of (&ip_next_T0[24])
+      }
+      if (!normalize_by_lengths || length == 0) {
+        _mm256_storeu_ps(&op[0], vop0);
+        _mm256_storeu_ps(&op[8], vop8);
+        _mm256_storeu_ps(&op[16], vop16);
+        _mm256_storeu_ps(&op[24], vop24);
+      } else {
+        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+      }
+    }
+  } else if (block_size == 16) {
+    // unrolling 2 times
+    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      __m256 vop0 = _mm256_setzero_ps();
+      __m256 vop8 = _mm256_setzero_ps();
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        vop0 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (0)))),
+                16)),
+            vop0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+        vop8 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (8)))),
+                16)),
+            vop8);
+        // skip unnecessary prefetch of (&ip_next_T0[8])
+      }
+      if (!normalize_by_lengths || length == 0) {
+        _mm256_storeu_ps(&op[0], vop0);
+        _mm256_storeu_ps(&op[8], vop8);
+      } else {
+        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+      }
+    }
+  } else {
+    // generic code
+    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
+    alignas(64) at::BFloat16 vtmp1[8] = {0};
+    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      int64_t j = 0;
+      for (; j + 8 <= block_size; j += 8) {
+        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+      }
+      for (; j < block_size; j++) {
+        op[j] = 0.0f;
+      }
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        j = 0;
+        for (; j + 8 <= block_size; j += 8) {
+          _mm256_storeu_ps(
+              &op[j],
+              _mm256_fmadd_ps(
+                  vwgt,
+                  _mm256_castsi256_ps(_mm256_slli_epi32(
+                      _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                          reinterpret_cast<const __m128i*>(&ip[j]))),
+                      16)),
+                  _mm256_loadu_ps(&op[j])));
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
+        }
+        for (; j < block_size; j++) {
+          vtmp1[0] = ip[j];
+          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
+              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
+              16));
+          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
+        }
+      }
+      if (normalize_by_lengths && length) {
+        float len_inv = 1.0f / length;
+        __m256 vlen_inv = _mm256_set1_ps(len_inv);
+        j = 0;
+        for (; j + 8 <= block_size; j += 8) {
+          _mm256_storeu_ps(
+              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+        }
+        for (; j < block_size; j++) {
+          op[j] = len_inv * op[j];
+        }
+      }
+    }
+  }
+  return dataInd == index_size;
+}
+bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__avx2_fma(
+    const int64_t block_size,
+    const int64_t output_size,
+    const int64_t index_size,
+    const int64_t data_size,
+    const at::BFloat16* input,
+    const int* indices,
+    const int* offsets,
+    const float* weights,
+    const float* scale_bias,
+    bool normalize_by_lengths,
+    float* out) {
+  return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<false>(
+      block_size,
+      output_size,
+      index_size,
+      data_size,
+      input,
+      indices,
+      offsets,
+      weights,
+      scale_bias,
+      normalize_by_lengths,
+      out);
+}
+bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__avx2_fma(
+    const int64_t block_size,
+    const int64_t output_size,
+    const int64_t index_size,
+    const int64_t data_size,
+    const at::BFloat16* input,
+    const int* indices,
+    const int* offsets,
+    const float* weights,
+    const float* scale_bias,
+    bool normalize_by_lengths,
+    float* out) {
+  return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<true>(
+      block_size,
+      output_size,
+      index_size,
+      data_size,
+      input,
+      indices,
+      offsets,
+      weights,
+      scale_bias,
+      normalize_by_lengths,
+      out);
+}
+
+template <bool IS_WEIGHT_POSITIONAL>
+static bool EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma(
+    const int64_t block_size,
+    const int64_t output_size,
+    const int64_t index_size,
+    const int64_t data_size,
+    const at::BFloat16* input,
+    const int64_t* indices,
+    const int64_t* offsets,
+    const float* weights,
+    const float* scale_bias,
+    bool normalize_by_lengths,
+    float* out) {
+  const int64_t prefdist_T0 = 16;
+  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+  const int64_t fused_block_size = block_size + 0;
+  int64_t dataInd = 0;
+  if (block_size == 128) {
+    // unrolling 16 times
+    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      __m256 vop0 = _mm256_setzero_ps();
+      __m256 vop8 = _mm256_setzero_ps();
+      __m256 vop16 = _mm256_setzero_ps();
+      __m256 vop24 = _mm256_setzero_ps();
+      __m256 vop32 = _mm256_setzero_ps();
+      __m256 vop40 = _mm256_setzero_ps();
+      __m256 vop48 = _mm256_setzero_ps();
+      __m256 vop56 = _mm256_setzero_ps();
+      __m256 vop64 = _mm256_setzero_ps();
+      __m256 vop72 = _mm256_setzero_ps();
+      __m256 vop80 = _mm256_setzero_ps();
+      __m256 vop88 = _mm256_setzero_ps();
+      __m256 vop96 = _mm256_setzero_ps();
+      __m256 vop104 = _mm256_setzero_ps();
+      __m256 vop112 = _mm256_setzero_ps();
+      __m256 vop120 = _mm256_setzero_ps();
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int64_t idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int64_t idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        vop0 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (0)))),
+                16)),
+            vop0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+        vop8 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (8)))),
+                16)),
+            vop8);
+        // skip unnecessary prefetch of (&ip_next_T0[8])
+        vop16 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (16)))),
+                16)),
+            vop16);
+        // skip unnecessary prefetch of (&ip_next_T0[16])
+        vop24 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (24)))),
+                16)),
+            vop24);
+        // skip unnecessary prefetch of (&ip_next_T0[24])
+        vop32 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (32)))),
+                16)),
+            vop32);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
+        vop40 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (40)))),
+                16)),
+            vop40);
+        // skip unnecessary prefetch of (&ip_next_T0[40])
+        vop48 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (48)))),
+                16)),
+            vop48);
+        // skip unnecessary prefetch of (&ip_next_T0[48])
+        vop56 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (56)))),
+                16)),
+            vop56);
+        // skip unnecessary prefetch of (&ip_next_T0[56])
+        vop64 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (64)))),
+                16)),
+            vop64);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
+        vop72 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (72)))),
+                16)),
+            vop72);
+        // skip unnecessary prefetch of (&ip_next_T0[72])
+        vop80 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (80)))),
+                16)),
+            vop80);
+        // skip unnecessary prefetch of (&ip_next_T0[80])
+        vop88 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (88)))),
+                16)),
+            vop88);
+        // skip unnecessary prefetch of (&ip_next_T0[88])
+        vop96 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (96)))),
+                16)),
+            vop96);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
+        vop104 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (104)))),
+                16)),
+            vop104);
+        // skip unnecessary prefetch of (&ip_next_T0[104])
+        vop112 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (112)))),
+                16)),
+            vop112);
+        // skip unnecessary prefetch of (&ip_next_T0[112])
+        vop120 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (120)))),
+                16)),
+            vop120);
+        // skip unnecessary prefetch of (&ip_next_T0[120])
+      }
+      if (!normalize_by_lengths || length == 0) {
+        _mm256_storeu_ps(&op[0], vop0);
+        _mm256_storeu_ps(&op[8], vop8);
+        _mm256_storeu_ps(&op[16], vop16);
+        _mm256_storeu_ps(&op[24], vop24);
+        _mm256_storeu_ps(&op[32], vop32);
+        _mm256_storeu_ps(&op[40], vop40);
+        _mm256_storeu_ps(&op[48], vop48);
+        _mm256_storeu_ps(&op[56], vop56);
+        _mm256_storeu_ps(&op[64], vop64);
+        _mm256_storeu_ps(&op[72], vop72);
+        _mm256_storeu_ps(&op[80], vop80);
+        _mm256_storeu_ps(&op[88], vop88);
+        _mm256_storeu_ps(&op[96], vop96);
+        _mm256_storeu_ps(&op[104], vop104);
+        _mm256_storeu_ps(&op[112], vop112);
+        _mm256_storeu_ps(&op[120], vop120);
+      } else {
+        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+      }
+    }
+  } else if (block_size == 64) {
+    // unrolling 8 times
+    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      __m256 vop0 = _mm256_setzero_ps();
+      __m256 vop8 = _mm256_setzero_ps();
+      __m256 vop16 = _mm256_setzero_ps();
+      __m256 vop24 = _mm256_setzero_ps();
+      __m256 vop32 = _mm256_setzero_ps();
+      __m256 vop40 = _mm256_setzero_ps();
+      __m256 vop48 = _mm256_setzero_ps();
+      __m256 vop56 = _mm256_setzero_ps();
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int64_t idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int64_t idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        vop0 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (0)))),
+                16)),
+            vop0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+        vop8 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (8)))),
+                16)),
+            vop8);
+        // skip unnecessary prefetch of (&ip_next_T0[8])
+        vop16 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (16)))),
+                16)),
+            vop16);
+        // skip unnecessary prefetch of (&ip_next_T0[16])
+        vop24 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (24)))),
+                16)),
+            vop24);
+        // skip unnecessary prefetch of (&ip_next_T0[24])
+        vop32 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (32)))),
+                16)),
+            vop32);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
+        vop40 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (40)))),
+                16)),
+            vop40);
+        // skip unnecessary prefetch of (&ip_next_T0[40])
+        vop48 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (48)))),
+                16)),
+            vop48);
+        // skip unnecessary prefetch of (&ip_next_T0[48])
+        vop56 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (56)))),
+                16)),
+            vop56);
+        // skip unnecessary prefetch of (&ip_next_T0[56])
+      }
+      if (!normalize_by_lengths || length == 0) {
+        _mm256_storeu_ps(&op[0], vop0);
+        _mm256_storeu_ps(&op[8], vop8);
+        _mm256_storeu_ps(&op[16], vop16);
+        _mm256_storeu_ps(&op[24], vop24);
+        _mm256_storeu_ps(&op[32], vop32);
+        _mm256_storeu_ps(&op[40], vop40);
+        _mm256_storeu_ps(&op[48], vop48);
+        _mm256_storeu_ps(&op[56], vop56);
+      } else {
+        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+      }
+    }
+  } else if (block_size == 32) {
+    // unrolling 4 times
+    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      __m256 vop0 = _mm256_setzero_ps();
+      __m256 vop8 = _mm256_setzero_ps();
+      __m256 vop16 = _mm256_setzero_ps();
+      __m256 vop24 = _mm256_setzero_ps();
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int64_t idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int64_t idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        vop0 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (0)))),
+                16)),
+            vop0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+        vop8 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (8)))),
+                16)),
+            vop8);
+        // skip unnecessary prefetch of (&ip_next_T0[8])
+        vop16 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (16)))),
+                16)),
+            vop16);
+        // skip unnecessary prefetch of (&ip_next_T0[16])
+        vop24 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (24)))),
+                16)),
+            vop24);
+        // skip unnecessary prefetch of (&ip_next_T0[24])
+      }
+      if (!normalize_by_lengths || length == 0) {
+        _mm256_storeu_ps(&op[0], vop0);
+        _mm256_storeu_ps(&op[8], vop8);
+        _mm256_storeu_ps(&op[16], vop16);
+        _mm256_storeu_ps(&op[24], vop24);
+      } else {
+        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+      }
+    }
+  } else if (block_size == 16) {
+    // unrolling 2 times
+    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      __m256 vop0 = _mm256_setzero_ps();
+      __m256 vop8 = _mm256_setzero_ps();
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int64_t idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int64_t idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        vop0 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (0)))),
+                16)),
+            vop0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
+        vop8 = _mm256_fmadd_ps(
+            vwgt,
+            _mm256_castsi256_ps(_mm256_slli_epi32(
+                _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                    reinterpret_cast<const __m128i*>(ip + (8)))),
+                16)),
+            vop8);
+        // skip unnecessary prefetch of (&ip_next_T0[8])
+      }
+      if (!normalize_by_lengths || length == 0) {
+        _mm256_storeu_ps(&op[0], vop0);
+        _mm256_storeu_ps(&op[8], vop8);
+      } else {
+        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
+        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+      }
+    }
+  } else {
+    // generic code
+    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
+    alignas(64) at::BFloat16 vtmp1[8] = {0};
+    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+      float* op = &out[rangeIndex * block_size];
+      int64_t j = 0;
+      for (; j + 8 <= block_size; j += 8) {
+        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+      }
+      for (; j < block_size; j++) {
+        op[j] = 0.0f;
+      }
+      if (dataInd != offsets[rangeIndex] - offsets[0]) {
+        return false;
+      }
+      int64_t end_offset = offsets[rangeIndex + 1];
+      int64_t length = end_offset - offsets[rangeIndex];
+      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
+           ++dataInd) {
+        const int64_t idx = indices[dataInd];
+        if (idx < 0 || idx >= data_size) {
+          return false;
+        }
+        float wgt = 1.f;
+        if (weights) {
+          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
+        }
+        __m256 vwgt = _mm256_set1_ps(wgt);
+        const at::BFloat16* ip = &input[idx * fused_block_size];
+        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
+            : dataInd;
+        const int64_t idx_pref_T0 = indices[next_T0];
+        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
+          return false;
+        }
+        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+        j = 0;
+        for (; j + 8 <= block_size; j += 8) {
+          _mm256_storeu_ps(
+              &op[j],
+              _mm256_fmadd_ps(
+                  vwgt,
+                  _mm256_castsi256_ps(_mm256_slli_epi32(
+                      _mm256_cvtepu16_epi32(_mm_loadu_si128(
+                          reinterpret_cast<const __m128i*>(&ip[j]))),
+                      16)),
+                  _mm256_loadu_ps(&op[j])));
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
+        }
+        for (; j < block_size; j++) {
+          vtmp1[0] = ip[j];
+          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
+              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
+              16));
+          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
+        }
+      }
+      if (normalize_by_lengths && length) {
+        float len_inv = 1.0f / length;
+        __m256 vlen_inv = _mm256_set1_ps(len_inv);
+        j = 0;
+        for (; j + 8 <= block_size; j += 8) {
+          _mm256_storeu_ps(
+              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+        }
+        for (; j < block_size; j++) {
+          op[j] = len_inv * op[j];
+        }
+      }
+    }
+  }
+  return dataInd == index_size;
+}
+bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__avx2_fma(
+    const int64_t block_size,
+    const int64_t output_size,
+    const int64_t index_size,
+    const int64_t data_size,
+    const at::BFloat16* input,
+    const int64_t* indices,
+    const int64_t* offsets,
+    const float* weights,
+    const float* scale_bias,
+    bool normalize_by_lengths,
+    float* out) {
+  return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<false>(
+      block_size,
+      output_size,
+      index_size,
+      data_size,
+      input,
+      indices,
+      offsets,
+      weights,
+      scale_bias,
+      normalize_by_lengths,
+      out);
+}
+bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__avx2_fma(
+    const int64_t block_size,
+    const int64_t output_size,
+    const int64_t index_size,
+    const int64_t data_size,
+    const at::BFloat16* input,
+    const int64_t* indices,
+    const int64_t* offsets,
+    const float* weights,
+    const float* scale_bias,
+    bool normalize_by_lengths,
+    float* out) {
+  return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<true>(
+      block_size,
+      output_size,
+      index_size,
+      data_size,
+      input,
+      indices,
+      offsets,
+      weights,
+      scale_bias,
+      normalize_by_lengths,
+      out);
+}
+
+template <bool IS_WEIGHT_POSITIONAL>
 static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
     const int64_t block_size,
     const int64_t output_size,
@@ -2483,6 +3772,7 @@
     }
   } else {
     // generic code
+    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
       float* op = &out[rangeIndex * block_size];
       int64_t j = 0;
@@ -2621,6 +3911,7 @@
     bool normalize_by_lengths,
     float* out) {
   const int64_t prefdist_T0 = 16;
+  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
   const int64_t fused_block_size = block_size + 0;
   int64_t dataInd = 0;
   if (block_size == 128) {
@@ -2666,7 +3957,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const uint8_t* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -2844,7 +4137,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const uint8_t* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -2953,7 +4248,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const uint8_t* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -3028,7 +4325,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const uint8_t* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
@@ -3060,6 +4359,7 @@
     }
   } else {
     // generic code
+    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
       float* op = &out[rangeIndex * block_size];
       int64_t j = 0;
@@ -3092,7 +4392,9 @@
         __m256 vwgt = _mm256_set1_ps(wgt);
         const uint8_t* ip = &input[idx * fused_block_size];
         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             ? (dataInd + prefdist_T0)
+            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
             : dataInd;
         const int64_t idx_pref_T0 = indices[next_T0];
         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
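For reference, every BF16 load in the generated AVX2 kernels above follows the same widening pattern: eight 16-bit bfloat16 values are loaded, zero-extended into 32-bit lanes, shifted left by 16 so the bfloat16 bits land in the upper half of an IEEE-754 single, and bit-cast to a float vector. A minimal standalone sketch of that conversion (the helper name bf16_to_fp32_avx2 is hypothetical and not part of this change; the intrinsics are the ones emitted above):

    #include <immintrin.h>
    #include <cstdint>

    // Widen 8 bfloat16 values (raw uint16_t bit patterns) to 8 fp32 lanes.
    static inline __m256 bf16_to_fp32_avx2(const uint16_t* src) {
      __m128i raw = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src)); // 8 x u16
      __m256i widened = _mm256_cvtepu16_epi32(raw);      // zero-extend to 8 x u32
      __m256i shifted = _mm256_slli_epi32(widened, 16);  // bf16 bits -> high half of fp32
      return _mm256_castsi256_ps(shifted);               // reinterpret as 8 x float
    }

Each unrolled block above feeds the widened lanes straight into a single _mm256_fmadd_ps per 8-element chunk, so the conversion adds no extra stores or loads on the accumulation path.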
diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py
index 402f3bb..7e4208c 100644
--- a/caffe2/perfkernels/hp_emblookup_codegen.py
+++ b/caffe2/perfkernels/hp_emblookup_codegen.py
@@ -4,7 +4,7 @@
 import sys
 
 
-sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1}
+sizeof = {"float": 4, "at::Half": 2, "at::BFloat16": 2, "uint8_t": 1}
 
 
 def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
@@ -24,6 +24,16 @@
                 "                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n"  # noqa
                 "            vop%d);" % (regid, regid, regid)
             )
+        elif InType == "at::BFloat16":
+            code.append(
+                "        vop%d = _mm256_fmadd_ps(\n"
+                "            vwgt,\n"
+                "            _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+                "                _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
+                "                    reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
+                "                16)),\n"  # noqa
+                "            vop%d);" % (regid, regid, regid)
+            )
         elif InType == "uint8_t":
             code.append(
                 "        vop%d = _mm256_fmadd_ps(\n"
@@ -104,6 +114,7 @@
 
     if InType == "uint8_t":
         code.append("        " + OutType + " wgt = 1.f;")
+        code.append("        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
         code.append("        " + OutType + " bio;")
         code.append("        if (weights) {")
         code.append(
@@ -133,7 +144,10 @@
     code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
     code.append(
         "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
-        "            ? (dataInd + prefdist_T0)\n            : dataInd;".format(
+        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+        "            ? (dataInd + prefdist_T0)\n"
+        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+        "            : dataInd;".format(
             IndexType
         )
     )
@@ -206,6 +220,18 @@
                 "                      reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                 "                  _mm256_loadu_ps(&op[j])));"
             )
+        elif InType == "at::BFloat16":
+            code.append(
+                "          _mm256_storeu_ps(\n"
+                "              &op[j],\n"
+                "              _mm256_fmadd_ps(\n"
+                "                  vwgt,\n"
+                "                  _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+                "                      _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
+                "                          reinterpret_cast<const __m128i*>(&ip[j]))),\n"
+                "                      16)),\n"
+                "                  _mm256_loadu_ps(&op[j])));"
+            )
         elif InType == "uint8_t":
             code.append(
                 "          _mm256_storeu_ps(\n"
@@ -229,7 +255,8 @@
     code = []
     if InType == "at::Half":
         code.append("    alignas(64) at::Half vtmp1[8] = {0};")
-
+    if InType == "at::BFloat16":
+        code.append("    alignas(64) at::BFloat16 vtmp1[8] = {0};")
 
 
     if use_offsets:
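The aligned at::BFloat16 vtmp1[8] scratch buffer added just above is only used by the generated remainder loop (when block_size is not a multiple of 8): one element is copied into vtmp1[0], widened with the same shift-by-16 sequence, and accumulated with std::fma. A scalar equivalent of that tail conversion, as a sketch (the helper name bf16_bits_to_fp32 is hypothetical, not part of the diff):

    #include <cstdint>
    #include <cstring>

    // A bfloat16 value is the upper 16 bits of an fp32 value, so shifting the
    // raw bits left by 16 and reinterpreting them as float recovers the value.
    static inline float bf16_bits_to_fp32(uint16_t bits) {
      uint32_t widened = static_cast<uint32_t>(bits) << 16;
      float out;
      std::memcpy(&out, &widened, sizeof(out));  // bit-exact reinterpretation
      return out;
    }

The generated code reaches the same result by routing the single element through the 8-lane SIMD path and reading back lane 0, which keeps the code-gen template uniform with the at::Half tail handling.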
@@ -291,6 +318,7 @@
 
     if InType == "uint8_t":
         code.append("        " + OutType + " wgt = 1.f;")
+        code.append("        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
         code.append("        " + OutType + " bio;")
         code.append("        if (weights) {")
         code.append(
@@ -320,7 +348,10 @@
     code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
     code.append(
         "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
-        "            ? (dataInd + prefdist_T0)\n            : dataInd;".format(
+        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+        "            ? (dataInd + prefdist_T0)\n"
+        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+        "            : dataInd;".format(
             IndexType
         )
     )
@@ -351,6 +382,14 @@
             "              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
         )
         code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
+    elif InType == "at::BFloat16":
+        code.append("          vtmp1[0] = ip[j];")
+        code.append(
+            "          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+            "              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),\n"
+            "              16));"
+        )
+        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
     elif InType == "uint8_t":
         code.append("          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
     else:
@@ -408,6 +447,8 @@
     ["int64_t", "int64_t", "float", "float", "float", "float"],
     ["int32_t", "int", "half", "at::Half", "float", "float"],
     ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
+    ["int32_t", "int", "bfloat16", "at::BFloat16", "float", "float"],
+    ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
     ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
     ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
 ]
@@ -422,6 +463,7 @@
 code.append("//// --------------------------\n")
 
 code.append("#include <c10/util/Half.h>")
+code.append("#include <c10/util/BFloat16.h>")
 code.append("#include <immintrin.h>")
 
 code.append("namespace caffe2 {\n")
@@ -461,6 +503,7 @@
     code += args
 
     code.append("  const " + IndexType + " prefdist_T0 = 16;")
+    code.append("  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)")
     # block_size is the number of elements and fused_block_size is the size of
     # an entire row, including scale and bias.
     offset = (8 // sizeof[InType]) if opts.fused else 0
@@ -484,6 +527,7 @@
     code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
     code.append("  } else {")
     code.append("    // generic code")
+    code.append("    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)")
     code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
     code.append("  }")
     code.append("  return dataInd == index_size;")