Revert D17826873: Adding support to offsets based Fused8BitRowwiseEmbeddingLookup

Test Plan: revert-hammer

Differential Revision: D17826873

Original commit changeset: 23c4a96d9252

fbshipit-source-id: 15ad64e49f922a859abc574b261ac0f857682ff4
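
For reference, the kernels deleted below implement an offsets-based lookup: offsets[i] gives the position in indices at which output row i starts, and the last row ends at index_size. A minimal scalar sketch of that contract follows (illustrative only, not part of this patch; the name ScalarEmbeddingLookupIdx_ref is hypothetical, and it assumes a plain row layout rather than the fused scale/bias layout and positional-weight option of the real kernels):

#include <cmath>
#include <cstdint>

// Scalar reference of the offsets-based pooled embedding lookup contract.
// Returns false on inconsistent offsets or out-of-range indices, true when
// every entry of indices[] has been consumed, mirroring the AVX2 kernels.
template <typename IndexType>
bool ScalarEmbeddingLookupIdx_ref(
    std::int64_t block_size,
    std::int64_t output_size,
    std::int64_t index_size,
    std::int64_t data_size,
    const float* input,          // data_size rows of block_size floats
    const IndexType* indices,
    const std::int64_t* offsets, // one start offset per output row
    const float* weights,        // optional per-index weights, may be nullptr
    bool normalize_by_lengths,
    float* out) {
  std::int64_t dataInd = 0;
  for (std::int64_t r = 0; r < output_size; ++r) {
    float* op = out + r * block_size;
    for (std::int64_t j = 0; j < block_size; ++j) {
      op[j] = 0.f;
    }
    if (dataInd != offsets[r]) {
      return false; // offsets must be consistent with the running position
    }
    const std::int64_t end =
        (r == output_size - 1) ? index_size : offsets[r + 1];
    const std::int64_t length = end - offsets[r];
    for (; dataInd < end; ++dataInd) {
      const IndexType idx = indices[dataInd];
      if (idx < 0 || idx >= data_size) {
        return false; // bounds check, as in the kernels
      }
      const float wgt = weights ? weights[dataInd] : 1.f; // non-positional weights
      const float* ip = input + idx * block_size;
      for (std::int64_t j = 0; j < block_size; ++j) {
        op[j] = std::fma(wgt, ip[j], op[j]); // weighted accumulation
      }
    }
    if (normalize_by_lengths && length) {
      const float len_inv = 1.f / length;
      for (std::int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv; // mean pooling
      }
    }
  }
  return dataInd == index_size; // all indices consumed
}

The deleted AVX2 code is this same loop specialized for block sizes 128/64/32/16 with unrolled __m256 accumulators, FMA, and prefetching of the row 16 indices ahead (prefdist_T0), plus a generic fallback path.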
diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc
deleted file mode 100644
index 9a0f240..0000000
--- a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc
+++ /dev/null
@@ -1,3160 +0,0 @@
-//// --------------------------
-//// ATTENTION:
-//// THIS CODE IS AUTOGENERATED
-//// BY caffe2/caffe2/perfkernels/hp_emblookup_codegen.py
-//// DO NOT MODIFY!!!
-//// --------------------------
-
-#include <c10/util/Half.h>
-#include <immintrin.h>
-namespace caffe2 {
-
-template <bool IS_WEIGHT_POSITIONAL>
-static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const float* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  const int prefdist_T0 = 16;
-  const int fused_block_size = block_size + 2;
-  int64_t dataInd = 0;
-  if (block_size == 128) {
-    // unrolling 16 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      __m256 vop64 = _mm256_setzero_ps();
-      __m256 vop72 = _mm256_setzero_ps();
-      __m256 vop80 = _mm256_setzero_ps();
-      __m256 vop88 = _mm256_setzero_ps();
-      __m256 vop96 = _mm256_setzero_ps();
-      __m256 vop104 = _mm256_setzero_ps();
-      __m256 vop112 = _mm256_setzero_ps();
-      __m256 vop120 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
-        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
-        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
-        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
-        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
-        // skip unnecessary prefetch of (&ip_next_T0[72])
-        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
-        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
-        // skip unnecessary prefetch of (&ip_next_T0[88])
-        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
-        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
-        // skip unnecessary prefetch of (&ip_next_T0[104])
-        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
-        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
-        // skip unnecessary prefetch of (&ip_next_T0[120])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-        _mm256_storeu_ps(&op[64], vop64);
-        _mm256_storeu_ps(&op[72], vop72);
-        _mm256_storeu_ps(&op[80], vop80);
-        _mm256_storeu_ps(&op[88], vop88);
-        _mm256_storeu_ps(&op[96], vop96);
-        _mm256_storeu_ps(&op[104], vop104);
-        _mm256_storeu_ps(&op[112], vop112);
-        _mm256_storeu_ps(&op[120], vop120);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
-        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
-        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
-        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
-        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
-        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
-        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
-        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
-      }
-    }
-  } else if (block_size == 64) {
-    // unrolling 8 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
-        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
-        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
-        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-      }
-    }
-  } else if (block_size == 32) {
-    // unrolling 4 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
-        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-      }
-    }
-  } else if (block_size == 16) {
-    // unrolling 2 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-      }
-    }
-  } else {
-    // generic code
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      int64_t j = 0;
-      for (; j + 8 <= block_size; j += 8) {
-        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
-      }
-      for (; j < block_size; j++) {
-        op[j] = 0.0f;
-      }
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j],
-              _mm256_fmadd_ps(
-                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
-          _mm_prefetch(
-              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
-        }
-        for (; j < block_size; j++) {
-          op[j] = std::fma(wgt, ip[j], op[j]);
-        }
-      }
-      if (normalize_by_lengths && length) {
-        float len_inv = 1.0f / length;
-        __m256 vlen_inv = _mm256_set1_ps(len_inv);
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
-        }
-        for (; j < block_size; j++) {
-          op[j] = len_inv * op[j];
-        }
-      }
-    }
-  }
-  return dataInd == index_size;
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const float* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float__avx2_fma<false>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const float* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float__avx2_fma<true>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-
-template <bool IS_WEIGHT_POSITIONAL>
-static bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_float_float__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const float* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  const int64_t prefdist_T0 = 16;
-  const int64_t fused_block_size = block_size + 2;
-  int64_t dataInd = 0;
-  if (block_size == 128) {
-    // unrolling 16 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      __m256 vop64 = _mm256_setzero_ps();
-      __m256 vop72 = _mm256_setzero_ps();
-      __m256 vop80 = _mm256_setzero_ps();
-      __m256 vop88 = _mm256_setzero_ps();
-      __m256 vop96 = _mm256_setzero_ps();
-      __m256 vop104 = _mm256_setzero_ps();
-      __m256 vop112 = _mm256_setzero_ps();
-      __m256 vop120 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
-        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
-        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
-        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
-        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
-        // skip unnecessary prefetch of (&ip_next_T0[72])
-        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
-        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
-        // skip unnecessary prefetch of (&ip_next_T0[88])
-        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
-        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
-        // skip unnecessary prefetch of (&ip_next_T0[104])
-        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
-        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
-        // skip unnecessary prefetch of (&ip_next_T0[120])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-        _mm256_storeu_ps(&op[64], vop64);
-        _mm256_storeu_ps(&op[72], vop72);
-        _mm256_storeu_ps(&op[80], vop80);
-        _mm256_storeu_ps(&op[88], vop88);
-        _mm256_storeu_ps(&op[96], vop96);
-        _mm256_storeu_ps(&op[104], vop104);
-        _mm256_storeu_ps(&op[112], vop112);
-        _mm256_storeu_ps(&op[120], vop120);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
-        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
-        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
-        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
-        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
-        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
-        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
-        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
-      }
-    }
-  } else if (block_size == 64) {
-    // unrolling 8 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
-        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
-        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
-        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-      }
-    }
-  } else if (block_size == 32) {
-    // unrolling 4 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
-        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-      }
-    }
-  } else if (block_size == 16) {
-    // unrolling 2 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-      }
-    }
-  } else {
-    // generic code
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      int64_t j = 0;
-      for (; j + 8 <= block_size; j += 8) {
-        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
-      }
-      for (; j < block_size; j++) {
-        op[j] = 0.0f;
-      }
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const float* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j],
-              _mm256_fmadd_ps(
-                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
-          _mm_prefetch(
-              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
-        }
-        for (; j < block_size; j++) {
-          op[j] = std::fma(wgt, ip[j], op[j]);
-        }
-      }
-      if (normalize_by_lengths && length) {
-        float len_inv = 1.0f / length;
-        __m256 vlen_inv = _mm256_set1_ps(len_inv);
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
-        }
-        for (; j < block_size; j++) {
-          op[j] = len_inv * op[j];
-        }
-      }
-    }
-  }
-  return dataInd == index_size;
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_float_float_false__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const float* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int64_t_float_float__avx2_fma<false>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_float_float_true__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const float* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int64_t_float_float__avx2_fma<true>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-
-template <bool IS_WEIGHT_POSITIONAL>
-static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const at::Half* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  const int prefdist_T0 = 16;
-  const int fused_block_size = block_size + 4;
-  int64_t dataInd = 0;
-  if (block_size == 128) {
-    // unrolling 16 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      __m256 vop64 = _mm256_setzero_ps();
-      __m256 vop72 = _mm256_setzero_ps();
-      __m256 vop80 = _mm256_setzero_ps();
-      __m256 vop88 = _mm256_setzero_ps();
-      __m256 vop96 = _mm256_setzero_ps();
-      __m256 vop104 = _mm256_setzero_ps();
-      __m256 vop112 = _mm256_setzero_ps();
-      __m256 vop120 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
-            vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
-            vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
-            vop16);
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
-            vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
-            vop32);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
-        vop40 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
-            vop40);
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
-            vop48);
-        // skip unnecessary prefetch of (&ip_next_T0[48])
-        vop56 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
-            vop56);
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-        vop64 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
-            vop64);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
-        vop72 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
-            vop72);
-        // skip unnecessary prefetch of (&ip_next_T0[72])
-        vop80 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
-            vop80);
-        // skip unnecessary prefetch of (&ip_next_T0[80])
-        vop88 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
-            vop88);
-        // skip unnecessary prefetch of (&ip_next_T0[88])
-        vop96 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
-            vop96);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
-        vop104 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
-            vop104);
-        // skip unnecessary prefetch of (&ip_next_T0[104])
-        vop112 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
-            vop112);
-        // skip unnecessary prefetch of (&ip_next_T0[112])
-        vop120 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
-            vop120);
-        // skip unnecessary prefetch of (&ip_next_T0[120])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-        _mm256_storeu_ps(&op[64], vop64);
-        _mm256_storeu_ps(&op[72], vop72);
-        _mm256_storeu_ps(&op[80], vop80);
-        _mm256_storeu_ps(&op[88], vop88);
-        _mm256_storeu_ps(&op[96], vop96);
-        _mm256_storeu_ps(&op[104], vop104);
-        _mm256_storeu_ps(&op[112], vop112);
-        _mm256_storeu_ps(&op[120], vop120);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
-        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
-        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
-        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
-        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
-        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
-        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
-        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
-      }
-    }
-  } else if (block_size == 64) {
-    // unrolling 8 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
-            vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
-            vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
-            vop16);
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
-            vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
-            vop32);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
-        vop40 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
-            vop40);
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
-            vop48);
-        // skip unnecessary prefetch of (&ip_next_T0[48])
-        vop56 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
-            vop56);
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-      }
-    }
-  } else if (block_size == 32) {
-    // unrolling 4 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
-            vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
-            vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
-            vop16);
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
-            vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-      }
-    }
-  } else if (block_size == 16) {
-    // unrolling 2 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
-            vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
-            vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-      }
-    }
-  } else {
-    // generic code
-    alignas(64) at::Half vtmp1[8] = {0};
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      int64_t j = 0;
-      for (; j + 8 <= block_size; j += 8) {
-        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
-      }
-      for (; j < block_size; j++) {
-        op[j] = 0.0f;
-      }
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j],
-              _mm256_fmadd_ps(
-                  vwgt,
-                  _mm256_cvtph_ps(_mm_loadu_si128(
-                      reinterpret_cast<const __m128i*>(&ip[j]))),
-                  _mm256_loadu_ps(&op[j])));
-          _mm_prefetch(
-              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
-        }
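-        // Scalar tail for block sizes that are not a multiple of 8: widen one
-        // fp16 value at a time through the aligned temporary buffer and keep
-        // lane 0 of the converted vector.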
-        for (; j < block_size; j++) {
-          vtmp1[0] = ip[j];
-          __m256 vtmp2 =
-              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
-          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
-        }
-      }
-      if (normalize_by_lengths && length) {
-        float len_inv = 1.0f / length;
-        __m256 vlen_inv = _mm256_set1_ps(len_inv);
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
-        }
-        for (; j < block_size; j++) {
-          op[j] = len_inv * op[j];
-        }
-      }
-    }
-  }
-  return dataInd == index_size;
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const at::Half* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float__avx2_fma<false>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const at::Half* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float__avx2_fma<true>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-
-template <bool IS_WEIGHT_POSITIONAL>
-static bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_half_float__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const at::Half* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  const int64_t prefdist_T0 = 16;
-  const int64_t fused_block_size = block_size + 4;
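-  // Each fp16 row carries 4 trailing at::Half slots (8 bytes, the fused
-  // scale/bias region) that this half->float path never needs to read.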
-  int64_t dataInd = 0;
-  if (block_size == 128) {
-    // unrolling 16 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      __m256 vop64 = _mm256_setzero_ps();
-      __m256 vop72 = _mm256_setzero_ps();
-      __m256 vop80 = _mm256_setzero_ps();
-      __m256 vop88 = _mm256_setzero_ps();
-      __m256 vop96 = _mm256_setzero_ps();
-      __m256 vop104 = _mm256_setzero_ps();
-      __m256 vop112 = _mm256_setzero_ps();
-      __m256 vop120 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
-            vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
-            vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
-            vop16);
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
-            vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
-            vop32);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
-        vop40 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
-            vop40);
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
-            vop48);
-        // skip unnecessary prefetch of (&ip_next_T0[48])
-        vop56 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
-            vop56);
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-        vop64 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
-            vop64);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
-        vop72 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
-            vop72);
-        // skip unnecessary prefetch of (&ip_next_T0[72])
-        vop80 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
-            vop80);
-        // skip unnecessary prefetch of (&ip_next_T0[80])
-        vop88 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
-            vop88);
-        // skip unnecessary prefetch of (&ip_next_T0[88])
-        vop96 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
-            vop96);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
-        vop104 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
-            vop104);
-        // skip unnecessary prefetch of (&ip_next_T0[104])
-        vop112 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
-            vop112);
-        // skip unnecessary prefetch of (&ip_next_T0[112])
-        vop120 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
-            vop120);
-        // skip unnecessary prefetch of (&ip_next_T0[120])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-        _mm256_storeu_ps(&op[64], vop64);
-        _mm256_storeu_ps(&op[72], vop72);
-        _mm256_storeu_ps(&op[80], vop80);
-        _mm256_storeu_ps(&op[88], vop88);
-        _mm256_storeu_ps(&op[96], vop96);
-        _mm256_storeu_ps(&op[104], vop104);
-        _mm256_storeu_ps(&op[112], vop112);
-        _mm256_storeu_ps(&op[120], vop120);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
-        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
-        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
-        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
-        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
-        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
-        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
-        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
-      }
-    }
-  } else if (block_size == 64) {
-    // unrolling 8 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
-            vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
-            vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
-            vop16);
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
-            vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
-            vop32);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
-        vop40 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
-            vop40);
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
-            vop48);
-        // skip unnecessary prefetch of (&ip_next_T0[48])
-        vop56 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
-            vop56);
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-      }
-    }
-  } else if (block_size == 32) {
-    // unrolling 4 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
-            vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
-            vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
-            vop16);
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
-            vop24);
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-      }
-    }
-  } else if (block_size == 16) {
-    // unrolling 2 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
-            vop0);
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtph_ps(
-                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
-            vop8);
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-      }
-    }
-  } else {
-    // generic code
-    alignas(64) at::Half vtmp1[8] = {0};
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      int64_t j = 0;
-      for (; j + 8 <= block_size; j += 8) {
-        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
-      }
-      for (; j < block_size; j++) {
-        op[j] = 0.0f;
-      }
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const at::Half* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j],
-              _mm256_fmadd_ps(
-                  vwgt,
-                  _mm256_cvtph_ps(_mm_loadu_si128(
-                      reinterpret_cast<const __m128i*>(&ip[j]))),
-                  _mm256_loadu_ps(&op[j])));
-          _mm_prefetch(
-              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
-        }
-        for (; j < block_size; j++) {
-          vtmp1[0] = ip[j];
-          __m256 vtmp2 =
-              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
-          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
-        }
-      }
-      if (normalize_by_lengths && length) {
-        float len_inv = 1.0f / length;
-        __m256 vlen_inv = _mm256_set1_ps(len_inv);
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
-        }
-        for (; j < block_size; j++) {
-          op[j] = len_inv * op[j];
-        }
-      }
-    }
-  }
-  return dataInd == index_size;
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_half_float_false__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const at::Half* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int64_t_half_float__avx2_fma<false>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_half_float_true__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const at::Half* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int64_t_half_float__avx2_fma<true>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-
-template <bool IS_WEIGHT_POSITIONAL>
-static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const uint8_t* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  const int prefdist_T0 = 16;
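-  // Rows are prefetched prefdist_T0 = 16 indices ahead with an L1 (T0) hint
-  // to hide memory latency on the gather.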
-  const int fused_block_size = block_size + 8;
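-  // Each fused row stores block_size quantized uint8 values followed by a
-  // float scale and a float bias (8 extra bytes per row).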
-  int64_t dataInd = 0;
-  if (block_size == 128) {
-    // unrolling 16 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      __m256 vop64 = _mm256_setzero_ps();
-      __m256 vop72 = _mm256_setzero_ps();
-      __m256 vop80 = _mm256_setzero_ps();
-      __m256 vop88 = _mm256_setzero_ps();
-      __m256 vop96 = _mm256_setzero_ps();
-      __m256 vop104 = _mm256_setzero_ps();
-      __m256 vop112 = _mm256_setzero_ps();
-      __m256 vop120 = _mm256_setzero_ps();
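-      // Each bag must start exactly where the previous one ended
-      // (contiguous offsets); otherwise reject the whole lookup.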
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
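-        // Fold the row's dequantization scale into the weight and pre-scale
-        // the bias, so each 8-element chunk below is one uint8->float convert
-        // plus a single FMA onto (accumulator + bias).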
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
-            _mm256_add_ps(vop0, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
-            _mm256_add_ps(vop8, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
-            _mm256_add_ps(vop16, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
-            _mm256_add_ps(vop24, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
-            _mm256_add_ps(vop32, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[32])
-        vop40 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
-            _mm256_add_ps(vop40, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
-            _mm256_add_ps(vop48, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[48])
-        vop56 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
-            _mm256_add_ps(vop56, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-        vop64 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
-            _mm256_add_ps(vop64, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
-        vop72 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
-            _mm256_add_ps(vop72, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[72])
-        vop80 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
-            _mm256_add_ps(vop80, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[80])
-        vop88 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
-            _mm256_add_ps(vop88, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[88])
-        vop96 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
-            _mm256_add_ps(vop96, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[96])
-        vop104 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
-            _mm256_add_ps(vop104, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[104])
-        vop112 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
-            _mm256_add_ps(vop112, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[112])
-        vop120 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
-            _mm256_add_ps(vop120, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[120])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-        _mm256_storeu_ps(&op[64], vop64);
-        _mm256_storeu_ps(&op[72], vop72);
-        _mm256_storeu_ps(&op[80], vop80);
-        _mm256_storeu_ps(&op[88], vop88);
-        _mm256_storeu_ps(&op[96], vop96);
-        _mm256_storeu_ps(&op[104], vop104);
-        _mm256_storeu_ps(&op[112], vop112);
-        _mm256_storeu_ps(&op[120], vop120);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
-        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
-        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
-        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
-        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
-        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
-        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
-        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
-      }
-    }
-  } else if (block_size == 64) {
-    // unrolling 8 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
-            _mm256_add_ps(vop0, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
-            _mm256_add_ps(vop8, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
-            _mm256_add_ps(vop16, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
-            _mm256_add_ps(vop24, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
-            _mm256_add_ps(vop32, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[32])
-        vop40 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
-            _mm256_add_ps(vop40, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
-            _mm256_add_ps(vop48, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[48])
-        vop56 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
-            _mm256_add_ps(vop56, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-      }
-    }
-  } else if (block_size == 32) {
-    // unrolling 4 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
-            _mm256_add_ps(vop0, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
-            _mm256_add_ps(vop8, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
-            _mm256_add_ps(vop16, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
-            _mm256_add_ps(vop24, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-      }
-    }
-  } else if (block_size == 16) {
-    // unrolling 2 times
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
-            _mm256_add_ps(vop0, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
-            _mm256_add_ps(vop8, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-      }
-    }
-  } else {
-    // generic code
-    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      int64_t j = 0;
-      for (; j + 8 <= block_size; j += 8) {
-        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
-      }
-      for (; j < block_size; j++) {
-        op[j] = 0.0f;
-      }
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j],
-              _mm256_fmadd_ps(
-                  vwgt,
-                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
-                      reinterpret_cast<const __m128i*>(&ip[j])))),
-                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
-          _mm_prefetch(
-              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
-        }
-        for (; j < block_size; j++) {
-          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);
-        }
-      }
-      if (normalize_by_lengths && length) {
-        float len_inv = 1.0f / length;
-        __m256 vlen_inv = _mm256_set1_ps(len_inv);
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
-        }
-        for (; j < block_size; j++) {
-          op[j] = len_inv * op[j];
-        }
-      }
-    }
-  }
-  return dataInd == index_size;
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const uint8_t* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<false>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const uint8_t* input,
-    const int* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<true>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-
-template <bool IS_WEIGHT_POSITIONAL>
-static bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const uint8_t* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  const int64_t prefdist_T0 = 16;
-  const int64_t fused_block_size = block_size + 8;
-  int64_t dataInd = 0;
-  if (block_size == 128) {
-    // unrolling 16 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      __m256 vop64 = _mm256_setzero_ps();
-      __m256 vop72 = _mm256_setzero_ps();
-      __m256 vop80 = _mm256_setzero_ps();
-      __m256 vop88 = _mm256_setzero_ps();
-      __m256 vop96 = _mm256_setzero_ps();
-      __m256 vop104 = _mm256_setzero_ps();
-      __m256 vop112 = _mm256_setzero_ps();
-      __m256 vop120 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
-            _mm256_add_ps(vop0, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
-            _mm256_add_ps(vop8, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
-            _mm256_add_ps(vop16, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
-            _mm256_add_ps(vop24, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
-            _mm256_add_ps(vop32, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[32])
-        vop40 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
-            _mm256_add_ps(vop40, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
-            _mm256_add_ps(vop48, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[48])
-        vop56 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
-            _mm256_add_ps(vop56, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-        vop64 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
-            _mm256_add_ps(vop64, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
-        vop72 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
-            _mm256_add_ps(vop72, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[72])
-        vop80 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
-            _mm256_add_ps(vop80, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[80])
-        vop88 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
-            _mm256_add_ps(vop88, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[88])
-        vop96 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
-            _mm256_add_ps(vop96, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[96])
-        vop104 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
-            _mm256_add_ps(vop104, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[104])
-        vop112 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
-            _mm256_add_ps(vop112, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[112])
-        vop120 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
-            _mm256_add_ps(vop120, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[120])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-        _mm256_storeu_ps(&op[64], vop64);
-        _mm256_storeu_ps(&op[72], vop72);
-        _mm256_storeu_ps(&op[80], vop80);
-        _mm256_storeu_ps(&op[88], vop88);
-        _mm256_storeu_ps(&op[96], vop96);
-        _mm256_storeu_ps(&op[104], vop104);
-        _mm256_storeu_ps(&op[112], vop112);
-        _mm256_storeu_ps(&op[120], vop120);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
-        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
-        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
-        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
-        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
-        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
-        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
-        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
-      }
-    }
-  } else if (block_size == 64) {
-    // unrolling 8 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      __m256 vop32 = _mm256_setzero_ps();
-      __m256 vop40 = _mm256_setzero_ps();
-      __m256 vop48 = _mm256_setzero_ps();
-      __m256 vop56 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
-            _mm256_add_ps(vop0, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
-            _mm256_add_ps(vop8, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
-            _mm256_add_ps(vop16, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
-            _mm256_add_ps(vop24, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-        vop32 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
-            _mm256_add_ps(vop32, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[32])
-        vop40 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
-            _mm256_add_ps(vop40, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[40])
-        vop48 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
-            _mm256_add_ps(vop48, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[48])
-        vop56 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
-            _mm256_add_ps(vop56, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[56])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-        _mm256_storeu_ps(&op[32], vop32);
-        _mm256_storeu_ps(&op[40], vop40);
-        _mm256_storeu_ps(&op[48], vop48);
-        _mm256_storeu_ps(&op[56], vop56);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
-        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
-        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
-        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
-      }
-    }
-  } else if (block_size == 32) {
-    // unrolling 4 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      __m256 vop16 = _mm256_setzero_ps();
-      __m256 vop24 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
-            _mm256_add_ps(vop0, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
-            _mm256_add_ps(vop8, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-        vop16 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
-            _mm256_add_ps(vop16, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[16])
-        vop24 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
-            _mm256_add_ps(vop24, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[24])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-        _mm256_storeu_ps(&op[16], vop16);
-        _mm256_storeu_ps(&op[24], vop24);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
-        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
-      }
-    }
-  } else if (block_size == 16) {
-    // unrolling 2 times
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      __m256 vop0 = _mm256_setzero_ps();
-      __m256 vop8 = _mm256_setzero_ps();
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int64_t end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int64_t length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        vop0 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
-            _mm256_add_ps(vop0, vbio));
-        _mm_prefetch(
-            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
-        vop8 = _mm256_fmadd_ps(
-            vwgt,
-            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
-            _mm256_add_ps(vop8, vbio));
-        // skip unnecessary prefetch of (&ip_next_T0[8])
-      }
-      if (!normalize_by_lengths || length == 0) {
-        _mm256_storeu_ps(&op[0], vop0);
-        _mm256_storeu_ps(&op[8], vop8);
-      } else {
-        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
-        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
-        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
-      }
-    }
-  } else {
-    // generic code
-    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
-      float* op = &out[rangeIndex * block_size];
-      int64_t j = 0;
-      for (; j + 8 <= block_size; j += 8) {
-        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
-      }
-      for (; j < block_size; j++) {
-        op[j] = 0.0f;
-      }
-      if (dataInd != offsets[rangeIndex]) {
-        return false;
-      }
-      int end_offset =
-          (rangeIndex == output_size - 1 ? index_size
-                                         : offsets[rangeIndex + 1]);
-      int length = end_offset - offsets[rangeIndex];
-      for (int64_t start = dataInd; dataInd < end_offset; ++dataInd) {
-        const int64_t idx = indices[dataInd];
-        if (idx < 0 || idx >= data_size) {
-          return false;
-        }
-        float wgt = 1.f;
-        float bio;
-        if (weights) {
-          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
-        }
-        const float* scale_bias = reinterpret_cast<const float*>(
-            &input[idx * fused_block_size + block_size]);
-        bio = wgt * scale_bias[1];
-        wgt = wgt * scale_bias[0];
-        __m256 vbio = _mm256_set1_ps(bio);
-        __m256 vwgt = _mm256_set1_ps(wgt);
-        const uint8_t* ip = &input[idx * fused_block_size];
-        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
-            ? (dataInd + prefdist_T0)
-            : dataInd;
-        const int64_t idx_pref_T0 = indices[next_T0];
-        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
-          return false;
-        }
-        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j],
-              _mm256_fmadd_ps(
-                  vwgt,
-                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
-                      reinterpret_cast<const __m128i*>(&ip[j])))),
-                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
-          _mm_prefetch(
-              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
-        }
-        for (; j < block_size; j++) {
-          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);
-        }
-      }
-      if (normalize_by_lengths && length) {
-        float len_inv = 1.0f / length;
-        __m256 vlen_inv = _mm256_set1_ps(len_inv);
-        j = 0;
-        for (; j + 8 <= block_size; j += 8) {
-          _mm256_storeu_ps(
-              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
-        }
-        for (; j < block_size; j++) {
-          op[j] = len_inv * op[j];
-        }
-      }
-    }
-  }
-  return dataInd == index_size;
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_uint8_t_float_false__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const uint8_t* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<false>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-bool Fused8BitRowwiseEmbeddingLookupIdx_int64_t_uint8_t_float_true__avx2_fma(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const uint8_t* input,
-    const int64_t* indices,
-    const int64_t* offsets,
-    const float* weights,
-    bool normalize_by_lengths,
-    float* out) {
-  return Fused8BitRowwiseEmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<true>(
-      block_size,
-      output_size,
-      index_size,
-      data_size,
-      input,
-      indices,
-      offsets,
-      weights,
-      normalize_by_lengths,
-      out);
-}
-
-} // namespace caffe2
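For reference while reading the removal above: each unrolled lane of the deleted AVX2 kernels repeats one dequantize-and-accumulate step. The sketch below isolates that step under the assumption of an -mavx2 -mfma build; the helper name dequant_fma8 is illustrative and does not appear in the patch.

#include <cstdint>
#include <immintrin.h>

// Accumulate eight 8-bit quantized values: acc = vwgt * float(ip[0..7]) + (acc + vbio),
// where vwgt = weight * scale and vbio = weight * bias, mirroring the
// _mm256_fmadd_ps(vwgt, ..., _mm256_add_ps(vop, vbio)) pattern in the deleted code.
static inline __m256 dequant_fma8(
    const std::uint8_t* ip, __m256 vwgt, __m256 vbio, __m256 acc) {
  const __m128i q8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip));
  const __m256i q32 = _mm256_cvtepu8_epi32(q8); // widen u8 -> i32
  const __m256 qf = _mm256_cvtepi32_ps(q32);    // i32 -> float
  return _mm256_fmadd_ps(vwgt, qf, _mm256_add_ps(acc, vbio));
}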
diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc
index 1b49f01..825251c 100644
--- a/caffe2/perfkernels/embedding_lookup_idx.cc
+++ b/caffe2/perfkernels/embedding_lookup_idx.cc
@@ -75,7 +75,7 @@
 }
 
 // Proxy back to generic implementation
-#define EMBEDDING_IDX_SPECIALIZATION(                                                                 \
+#define EMBEDDING_SPECIALIZATION(                                                                     \
     IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL)                                     \
   bool                                                                                                \
       EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base(     \
@@ -209,20 +209,20 @@
         "the size of the indices tensor, but it appears not.");                                       \
   }
 
-EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
-EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
-EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
-EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
-EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
-EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
+EMBEDDING_SPECIALIZATION(int32_t, float, float, float, false);
+EMBEDDING_SPECIALIZATION(int64_t, float, float, float, false);
+EMBEDDING_SPECIALIZATION(int32_t, half, at::Half, float, false);
+EMBEDDING_SPECIALIZATION(int64_t, half, at::Half, float, false);
+EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
+EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
 
-EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true);
-EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
-EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
-EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
-EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
-EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
+EMBEDDING_SPECIALIZATION(int32_t, float, float, float, true);
+EMBEDDING_SPECIALIZATION(int64_t, float, float, float, true);
+EMBEDDING_SPECIALIZATION(int32_t, half, at::Half, float, true);
+EMBEDDING_SPECIALIZATION(int64_t, half, at::Half, float, true);
+EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
+EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
 
-#undef EMBEDDING_IDX_SPECIALIZATION
+#undef EMBEDDING_SPECIALIZATION
 
 } // namespace caffe2
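The hunk above only restores the macro's previous name; the generated symbols are unchanged because the arguments, not the macro name, are pasted into the function identifiers. A minimal sketch of that token-pasting pattern (simplified: the real macro also takes an InType parameter and the full argument list):

// Simplified stand-in for EMBEDDING_SPECIALIZATION; the instantiation below
// defines EmbeddingLookupIdx_int32_t_float_float_false__base(), the same
// symbol name either macro spelling would produce.
#define DEMO_EMBEDDING_SPECIALIZATION(IndexType, InTypeName, OutType, IS_WEIGHT_POSITIONAL) \
  bool EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base() { \
    return true;                                                                            \
  }

DEMO_EMBEDDING_SPECIALIZATION(int32_t, float, float, false)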
diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc
deleted file mode 100644
index 65875e6..0000000
--- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc
+++ /dev/null
@@ -1,212 +0,0 @@
-#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h"
-
-#include "caffe2/core/types.h"
-#include "caffe2/perfkernels/common.h"
-#include "caffe2/utils/cpuid.h"
-
-namespace caffe2 {
-
-/**
- * Base implementation does runtime dispatch for each segment of reduction
- * @return false if there is an out-of-bound error
- */
-template <
-    typename IndexType,
-    typename InType,
-    typename OutType,
-    bool IS_WEIGHT_POSITIONAL = false>
-static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
-    const int64_t block_size,
-    const int64_t output_size,
-    const int64_t index_size,
-    const int64_t data_size,
-    const InType* input,
-    const IndexType* indices,
-    const int64_t* offsets,
-    const float* weights, // optional, can be null for sum reducer
-    bool normalize_by_lengths,
-    OutType* out) {
-  // block_size is the number of elements and fused_block_size is the size of
-  // an entire row, including scale and bias.
-  const auto scale_bias_offset = 8 / sizeof(InType);
-  const int64_t fused_block_size = block_size + scale_bias_offset;
-  int64_t current = 0;
-  for (int m = 0; m < output_size; ++m) {
-    memset(out, 0, sizeof(OutType) * block_size);
-    if (current != offsets[m]) {
-      return false;
-    }
-    int64_t start_offset = offsets[m];
-    int64_t end_offset = (m == output_size - 1 ? index_size : offsets[m + 1]);
-    int64_t length = end_offset - start_offset;
-    for (int i = start_offset; i < end_offset; ++i) {
-      int64_t idx = indices[current];
-      if (idx < 0 || idx >= data_size) {
-        return false;
-      }
-#ifdef __GNUC__
-      if (current + 1 < index_size) {
-        __builtin_prefetch(
-            input + fused_block_size * indices[current + 1], 0, 1);
-      }
-#endif // __GNUC__
-
-      const float* scale_bias = reinterpret_cast<const float*>(
-          input + fused_block_size * indices[current] + block_size);
-
-      float weight = 1.0f;
-      if (weights) {
-        weight = weights[IS_WEIGHT_POSITIONAL ? i : current];
-      }
-      const float scale = weight * scale_bias[0];
-      const float bias = weight * scale_bias[1];
-
-      for (int j = 0; j < block_size; ++j) {
-        out[j] += scale * input[fused_block_size * indices[current] + j] + bias;
-      }
-
-      ++current;
-    }
-    if (normalize_by_lengths && length) {
-      float scale = 1.f / length;
-      for (int j = 0; j < block_size; ++j) {
-        out[j] *= scale;
-      }
-    }
-    out += block_size;
-  }
-  return current == index_size;
-}
-
-// Proxy back to generic implementation
-#define FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(IndexType, OutType)                 \
-  bool                                                                                      \
-      Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__base(     \
-          const int64_t block_size,                                                         \
-          const int64_t output_size,                                                        \
-          const int64_t index_size,                                                         \
-          const int64_t data_size,                                                          \
-          const uint8_t* input,                                                             \
-          const IndexType* indices,                                                         \
-          const int64_t* offsets,                                                           \
-          const float* weights,                                                             \
-          bool normalize_by_lengths,                                                        \
-          OutType* out) {                                                                   \
-    return Fused8BitRowwiseEmbeddingLookupGenericSlowIdx<                                   \
-        IndexType,                                                                          \
-        uint8_t,                                                                            \
-        OutType,                                                                            \
-        false>(                                                                             \
-        block_size,                                                                         \
-        output_size,                                                                        \
-        index_size,                                                                         \
-        data_size,                                                                          \
-        input,                                                                              \
-        indices,                                                                            \
-        offsets,                                                                            \
-        weights,                                                                            \
-        normalize_by_lengths,                                                               \
-        out);                                                                               \
-  }                                                                                         \
-  decltype(                                                                                 \
-      Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__base)     \
-      Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \
-  bool Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType(                  \
-      const int64_t block_size,                                                             \
-      const int64_t output_size,                                                            \
-      const int64_t index_size,                                                             \
-      const int64_t data_size,                                                              \
-      const uint8_t* input,                                                                 \
-      const IndexType* indices,                                                             \
-      const int64_t* offsets,                                                               \
-      const float* weights,                                                                 \
-      bool normalize_by_lengths,                                                            \
-      OutType* out) {                                                                       \
-    const int32_t one = 1;                                                                  \
-    CAFFE_ENFORCE_EQ(                                                                       \
-        reinterpret_cast<const uint8_t*>(&one)[0],                                          \
-        1,                                                                                  \
-        "Fused8BitRowwiseEmbeddingLookup is not supported on this platform");               \
-    AVX2_FMA_DO(                                                                            \
-        Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false,         \
-        block_size,                                                                         \
-        output_size,                                                                        \
-        index_size,                                                                         \
-        data_size,                                                                          \
-        input,                                                                              \
-        indices,                                                                            \
-        offsets,                                                                            \
-        weights,                                                                            \
-        normalize_by_lengths,                                                               \
-        out);                                                                               \
-    BASE_DO(                                                                                \
-        Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false,         \
-        block_size,                                                                         \
-        output_size,                                                                        \
-        index_size,                                                                         \
-        data_size,                                                                          \
-        input,                                                                              \
-        indices,                                                                            \
-        offsets,                                                                            \
-        weights,                                                                            \
-        normalize_by_lengths,                                                               \
-        out);                                                                               \
-  }                                                                                         \
-  template <>                                                                               \
-  void Fused8BitRowwiseEmbeddingLookupIdx<IndexType, uint8_t, OutType, false>(              \
-      const int64_t block_size,                                                             \
-      const int64_t output_size,                                                            \
-      const int64_t index_size,                                                             \
-      const int64_t data_size,                                                              \
-      const uint8_t* input,                                                                 \
-      const IndexType* indices,                                                             \
-      const int64_t* offsets,                                                               \
-      const float* weights,                                                                 \
-      bool normalize_by_lengths,                                                            \
-      OutType* out) {                                                                       \
-    bool success =                                                                          \
-        Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType(                 \
-            block_size,                                                                     \
-            output_size,                                                                    \
-            index_size,                                                                     \
-            data_size,                                                                      \
-            input,                                                                          \
-            indices,                                                                        \
-            offsets,                                                                        \
-            weights,                                                                        \
-            normalize_by_lengths,                                                           \
-            out);                                                                           \
-    if (success) {                                                                          \
-      return;                                                                               \
-    }                                                                                       \
-    int64_t current = 0;                                                                    \
-    for (int m = 0; m < output_size; ++m) {                                                 \
-      for (int64_t i = offsets[m];                                                          \
-           i < (m == output_size - 1 ? index_size : offsets[m + 1]);                        \
-           ++i) {                                                                           \
-        CAFFE_ENFORCE_LT(current, index_size);                                              \
-        IndexType idx = indices[current];                                                   \
-        CAFFE_ENFORCE(                                                                      \
-            0 <= idx && idx < data_size,                                                    \
-            "Index ",                                                                       \
-            current,                                                                        \
-            " is out of bounds: ",                                                          \
-            idx,                                                                            \
-            ", range 0 to ",                                                                \
-            data_size);                                                                     \
-        ++current;                                                                          \
-      }                                                                                     \
-    }                                                                                       \
-    CAFFE_ENFORCE_EQ(                                                                       \
-        current,                                                                            \
-        index_size,                                                                         \
-        "Your input seems to be incorrect: the sum of lengths values should be "            \
-        "the size of the indices tensor, but it appears not.");                             \
-  }
-
-FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int32_t, float);
-FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int64_t, float);
-
-#undef FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION
-
-} // namespace caffe2
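The deleted generic path above also documents the offsets contract: offsets[m] must equal the running position in indices when segment m begins (the current != offsets[m] check), and the last segment ends at index_size. Below is a minimal sketch of producing such offsets from per-segment lengths; the helper name lengths_to_offsets is illustrative.

#include <cstdint>
#include <vector>

// offsets[m] is the exclusive prefix sum of lengths, so segment m covers
// indices[offsets[m]] .. indices[(m == output_size - 1 ? index_size : offsets[m + 1]) - 1].
static std::vector<std::int64_t> lengths_to_offsets(
    const std::vector<std::int64_t>& lengths) {
  std::vector<std::int64_t> offsets(lengths.size());
  std::int64_t running = 0;
  for (std::size_t m = 0; m < lengths.size(); ++m) {
    offsets[m] = running; // matches the current != offsets[m] check above
    running += lengths[m];
  }
  return offsets; // running now equals index_size
}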
diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h
deleted file mode 100644
index 64152ef..0000000
--- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h
+++ /dev/null
@@ -1,57 +0,0 @@
-#pragma once
-
-#include <cstdint>
-
-namespace caffe2 {
-
-/**
- * Embedding lookup with reduction.
- *
- * `input` of size data_size * (block_size + 8B)
- * `indices` of size index_size
- * `offsets` of size output_size
- * `weights` nullptr or array of size index_size
- * `out` of size output_size * block_size
- *
- * Note that block_size should be the number of quantized values per row in the
- * data, i.e. excluding the scale and bias. The total (fused) block size is
- * assumed to be this block_size, plus 4 bytes for scale and 4 bytes for bias.
- *
- * Behavior is roughly equivalent to pseudocode:
- *
- * pos = 0
- * fused_block_size = block_size + 8B // quantized values and scale and bias
- * for (i = 0..index_size-1)
- *   for (k = 0..block_size-1)
- *     out[i*block_size + k] = 0
- *   start_offset = offsets[i]
- *   end_offset = i == output_size-1 ? index_size : offsets[i+1] - 1
- *   length = end_offset - start_offset
- *   for (j = start_offset..end_offset)
- *     for (k = 0..block_size-1)
- *       out[i*block_size + k] += input[indices[pos]*(fused_block_size) + k] *
- *           (weights ? weights[IS_WEIGHT_POSITIONAL ? j : pos] : 1.0)
- *     pos += 1
- *   if (normalize_weights && length > 0)
- *     for (k = 0..block_size-1)
- *       out[i*block_size + k] /= length
- *
- */
-
-template <
-    typename IndexType,
-    typename InType,
-    typename OutType,
-    bool IS_WEIGHT_POSITIONAL = false>
-void Fused8BitRowwiseEmbeddingLookupIdx(
-    const std::int64_t block_size,
-    const std::int64_t output_size,
-    const std::int64_t index_size,
-    const std::int64_t data_size,
-    const InType* input,
-    const IndexType* indices,
-    const int64_t* offsets,
-    const float* weights, // optional, can be null for non-weighted sum
-    bool normalize_by_lengths,
-    OutType* out);
-} // namespace caffe2
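For call sites affected by this revert, here is a minimal caller sketch of the removed entry point declared above; the wrapper name example_lookup and the vector-based buffers are illustrative, and it assumes the deleted fused_8bit_rowwise_embedding_lookup_idx.h is still on the include path.

#include <cstdint>
#include <vector>
#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h"

// Each fused row holds block_size uint8 values followed by a 4-byte float
// scale and a 4-byte float bias, as described in the header comment above.
void example_lookup(
    const std::vector<std::uint8_t>& fused_rows, // data_size * (block_size + 8)
    const std::vector<std::int64_t>& indices,    // index_size
    const std::vector<std::int64_t>& offsets,    // output_size
    std::int64_t block_size,
    std::int64_t output_size,
    std::int64_t data_size,
    std::vector<float>& out) {                   // output_size * block_size
  caffe2::Fused8BitRowwiseEmbeddingLookupIdx<std::int64_t, std::uint8_t, float>(
      block_size,
      output_size,
      static_cast<std::int64_t>(indices.size()), // index_size
      data_size,
      fused_rows.data(),
      indices.data(),
      offsets.data(),
      nullptr, // weights: plain (unweighted) sum
      false,   // normalize_by_lengths
      out.data());
}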