| //// -------------------------- | 
 | //// ATTENTION: | 
 | //// THIS CODE IS AUTOGENERATED | 
 | //// BY hp_emblookup_codegen.py | 
 | //// DO NOT MODIFY!!! | 
 | //// -------------------------- | 
 |  | 
#include <c10/util/Half.h>
#include <cmath> // std::fma in the scalar tail of the generic path
#include <immintrin.h>
 | namespace caffe2 { | 
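
// Each kernel below gathers lengths[r] rows of width block_size from `input`
// via `indices`, accumulates them (optionally scaled per row by `weights`)
// into out[r * block_size], and returns false if any index falls outside
// [0, data_size) or if the lengths do not consume exactly index_size indices.
// With normalize_by_lengths set and a non-empty segment, the accumulated row
// is scaled by 1 / lengths[r]. IS_WEIGHT_POSITIONAL chooses whether `weights`
// is indexed by the position within the segment or by the running index
// position; `scale_bias` is not read by these float/half kernels.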
 |  | 
 | template <bool IS_WEIGHT_POSITIONAL> | 
 | static bool EmbeddingLookup_int32_t_float_float__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const float* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
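  // prefdist_T0 is the software-prefetch distance in indices: the row that is
  // prefdist_T0 positions ahead in `indices` is prefetched (_MM_HINT_T0)
  // while the current row is accumulated. fused_block_size equals block_size
  // because float rows carry no fused scale/bias columns.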
 |   const int prefdist_T0 = 16; | 
 |   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |   const int fused_block_size = block_size + 0; | 
 |   int dataInd = 0; | 
 |   if (block_size == 128) { | 
 |     // unrolling 16 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       __m256 vop64 = _mm256_setzero_ps(); | 
 |       __m256 vop72 = _mm256_setzero_ps(); | 
 |       __m256 vop80 = _mm256_setzero_ps(); | 
 |       __m256 vop88 = _mm256_setzero_ps(); | 
 |       __m256 vop96 = _mm256_setzero_ps(); | 
 |       __m256 vop104 = _mm256_setzero_ps(); | 
 |       __m256 vop112 = _mm256_setzero_ps(); | 
 |       __m256 vop120 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
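        // Compute the index prefdist_T0 positions ahead (clamped at the end
        // of the index array) and bounds-check it before forming the
        // prefetch addresses below.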
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0); | 
 |         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0); | 
 |         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0); | 
 |         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |         vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0); | 
 |         vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[72]) | 
 |         vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0); | 
 |         vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[88]) | 
 |         vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0); | 
 |         vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[104]) | 
 |         vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0); | 
 |         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[120]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |         _mm256_storeu_ps(&op[64], vop64); | 
 |         _mm256_storeu_ps(&op[72], vop72); | 
 |         _mm256_storeu_ps(&op[80], vop80); | 
 |         _mm256_storeu_ps(&op[88], vop88); | 
 |         _mm256_storeu_ps(&op[96], vop96); | 
 |         _mm256_storeu_ps(&op[104], vop104); | 
 |         _mm256_storeu_ps(&op[112], vop112); | 
 |         _mm256_storeu_ps(&op[120], vop120); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 64) { | 
 |     // unrolling 8 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0); | 
 |         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0); | 
 |         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0); | 
 |         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 32) { | 
 |     // unrolling 4 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0); | 
 |         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 16) { | 
 |     // unrolling 2 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else { | 
 |     // generic code | 
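    // Any other block_size: zero the output row, accumulate 8 floats at a
    // time with FMA, and finish the remainder (block_size % 8) with scalar
    // std::fma.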
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       int64_t j = 0; | 
 |       for (; j + 8 <= block_size; j += 8) { | 
 |         _mm256_storeu_ps(op + j, _mm256_setzero_ps()); | 
 |       } | 
 |       for (; j < block_size; j++) { | 
 |         op[j] = 0.0f; | 
 |       } | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], | 
 |               _mm256_fmadd_ps( | 
 |                   vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); | 
 |           _mm_prefetch( | 
 |               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = std::fma(wgt, ip[j], op[j]); | 
 |         } | 
 |       } | 
 |       if (normalize_by_lengths && lengths[rangeIndex]) { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         float len_inv = 1.0f / lengths[rangeIndex]; | 
 |         __m256 vlen_inv = _mm256_set1_ps(len_inv); | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = len_inv * op[j]; | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |   return dataInd == index_size; | 
 | } | 
 | bool EmbeddingLookup_int32_t_float_float_false__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const float* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int32_t_float_float__avx2_fma<false>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 | bool EmbeddingLookup_int32_t_float_float_true__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const float* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int32_t_float_float__avx2_fma<true>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 |  | 
 | template <bool IS_WEIGHT_POSITIONAL> | 
 | static bool EmbeddingLookup_int64_t_float_float__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const float* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   const int64_t prefdist_T0 = 16; | 
 |   const int64_t fused_block_size = block_size + 0; | 
 |   int64_t dataInd = 0; | 
 |   if (block_size == 128) { | 
 |     // unrolling 16 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       __m256 vop64 = _mm256_setzero_ps(); | 
 |       __m256 vop72 = _mm256_setzero_ps(); | 
 |       __m256 vop80 = _mm256_setzero_ps(); | 
 |       __m256 vop88 = _mm256_setzero_ps(); | 
 |       __m256 vop96 = _mm256_setzero_ps(); | 
 |       __m256 vop104 = _mm256_setzero_ps(); | 
 |       __m256 vop112 = _mm256_setzero_ps(); | 
 |       __m256 vop120 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0); | 
 |         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0); | 
 |         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0); | 
 |         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |         vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0); | 
 |         vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[72]) | 
 |         vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0); | 
 |         vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[88]) | 
 |         vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0); | 
 |         vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[104]) | 
 |         vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0); | 
 |         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[120]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |         _mm256_storeu_ps(&op[64], vop64); | 
 |         _mm256_storeu_ps(&op[72], vop72); | 
 |         _mm256_storeu_ps(&op[80], vop80); | 
 |         _mm256_storeu_ps(&op[88], vop88); | 
 |         _mm256_storeu_ps(&op[96], vop96); | 
 |         _mm256_storeu_ps(&op[104], vop104); | 
 |         _mm256_storeu_ps(&op[112], vop112); | 
 |         _mm256_storeu_ps(&op[120], vop120); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 64) { | 
 |     // unrolling 8 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0); | 
 |         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0); | 
 |         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0); | 
 |         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 32) { | 
 |     // unrolling 4 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0); | 
 |         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 16) { | 
 |     // unrolling 2 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else { | 
 |     // generic code | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       int64_t j = 0; | 
 |       for (; j + 8 <= block_size; j += 8) { | 
 |         _mm256_storeu_ps(op + j, _mm256_setzero_ps()); | 
 |       } | 
 |       for (; j < block_size; j++) { | 
 |         op[j] = 0.0f; | 
 |       } | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const float* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], | 
 |               _mm256_fmadd_ps( | 
 |                   vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); | 
 |           _mm_prefetch( | 
 |               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = std::fma(wgt, ip[j], op[j]); | 
 |         } | 
 |       } | 
 |       if (normalize_by_lengths && lengths[rangeIndex]) { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         float len_inv = 1.0f / lengths[rangeIndex]; | 
 |         __m256 vlen_inv = _mm256_set1_ps(len_inv); | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = len_inv * op[j]; | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |   return dataInd == index_size; | 
 | } | 
 | bool EmbeddingLookup_int64_t_float_float_false__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const float* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int64_t_float_float__avx2_fma<false>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 | bool EmbeddingLookup_int64_t_float_float_true__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const float* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int64_t_float_float__avx2_fma<true>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 |  | 
 | template <bool IS_WEIGHT_POSITIONAL> | 
 | static bool EmbeddingLookup_int32_t_half_float__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const at::Half* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   const int prefdist_T0 = 16; | 
 |   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |   const int fused_block_size = block_size + 0; | 
 |   int dataInd = 0; | 
 |   if (block_size == 128) { | 
 |     // unrolling 16 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       __m256 vop64 = _mm256_setzero_ps(); | 
 |       __m256 vop72 = _mm256_setzero_ps(); | 
 |       __m256 vop80 = _mm256_setzero_ps(); | 
 |       __m256 vop88 = _mm256_setzero_ps(); | 
 |       __m256 vop96 = _mm256_setzero_ps(); | 
 |       __m256 vop104 = _mm256_setzero_ps(); | 
 |       __m256 vop112 = _mm256_setzero_ps(); | 
 |       __m256 vop120 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), | 
 |             vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), | 
 |             vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), | 
 |             vop16); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), | 
 |             vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), | 
 |             vop32); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0); | 
 |         vop40 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), | 
 |             vop40); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), | 
 |             vop48); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[48]) | 
 |         vop56 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), | 
 |             vop56); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |         vop64 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))), | 
 |             vop64); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0); | 
 |         vop72 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))), | 
 |             vop72); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[72]) | 
 |         vop80 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))), | 
 |             vop80); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[80]) | 
 |         vop88 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))), | 
 |             vop88); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[88]) | 
 |         vop96 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))), | 
 |             vop96); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0); | 
 |         vop104 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))), | 
 |             vop104); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[104]) | 
 |         vop112 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))), | 
 |             vop112); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[112]) | 
 |         vop120 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))), | 
 |             vop120); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[120]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |         _mm256_storeu_ps(&op[64], vop64); | 
 |         _mm256_storeu_ps(&op[72], vop72); | 
 |         _mm256_storeu_ps(&op[80], vop80); | 
 |         _mm256_storeu_ps(&op[88], vop88); | 
 |         _mm256_storeu_ps(&op[96], vop96); | 
 |         _mm256_storeu_ps(&op[104], vop104); | 
 |         _mm256_storeu_ps(&op[112], vop112); | 
 |         _mm256_storeu_ps(&op[120], vop120); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 64) { | 
 |     // unrolling 8 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), | 
 |             vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), | 
 |             vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), | 
 |             vop16); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), | 
 |             vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), | 
 |             vop32); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0); | 
 |         vop40 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), | 
 |             vop40); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), | 
 |             vop48); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[48]) | 
 |         vop56 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), | 
 |             vop56); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 32) { | 
 |     // unrolling 4 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), | 
 |             vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), | 
 |             vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), | 
 |             vop16); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), | 
 |             vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 16) { | 
 |     // unrolling 2 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), | 
 |             vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), | 
 |             vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else { | 
    // generic path: handles any block_size with 8-wide fp16->fp32 vector
    // iterations plus a scalar tail that widens one element at a time via the
    // vtmp1 scratch buffer below
 |     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers) | 
 |     alignas(64) at::Half vtmp1[8] = {0}; | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       int64_t j = 0; | 
 |       for (; j + 8 <= block_size; j += 8) { | 
 |         _mm256_storeu_ps(op + j, _mm256_setzero_ps()); | 
 |       } | 
 |       for (; j < block_size; j++) { | 
 |         op[j] = 0.0f; | 
 |       } | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], | 
 |               _mm256_fmadd_ps( | 
 |                   vwgt, | 
 |                   _mm256_cvtph_ps(_mm_loadu_si128( | 
 |                       reinterpret_cast<const __m128i*>(&ip[j]))), | 
 |                   _mm256_loadu_ps(&op[j]))); | 
 |           _mm_prefetch( | 
 |               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           vtmp1[0] = ip[j]; | 
 |           __m256 vtmp2 = | 
 |               _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1))); | 
          op[j] = std::fma(wgt, _mm256_cvtss_f32(vtmp2), op[j]);
 |         } | 
 |       } | 
 |       if (normalize_by_lengths && lengths[rangeIndex]) { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         float len_inv = 1.0f / lengths[rangeIndex]; | 
 |         __m256 vlen_inv = _mm256_set1_ps(len_inv); | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = len_inv * op[j]; | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |   return dataInd == index_size; | 
 | } | 
 | bool EmbeddingLookup_int32_t_half_float_false__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const at::Half* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int32_t_half_float__avx2_fma<false>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 | bool EmbeddingLookup_int32_t_half_float_true__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const at::Half* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int32_t_half_float__avx2_fma<true>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 |  | 
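// Accumulates (optionally weighted, optionally length-normalized) rows of a
// fp16 embedding table into fp32 output using int64_t indices. Rows are
// gathered per lengths[] segment, widened with _mm256_cvtph_ps, and
// accumulated with FMA; unrolled paths cover block_size 128/64/32/16 and a
// generic path handles every other size. Returns false if an index falls
// outside [0, data_size) or the lengths are inconsistent with index_size.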
 | template <bool IS_WEIGHT_POSITIONAL> | 
 | static bool EmbeddingLookup_int64_t_half_float__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const at::Half* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   const int64_t prefdist_T0 = 16; | 
 |   const int64_t fused_block_size = block_size + 0; | 
 |   int64_t dataInd = 0; | 
 |   if (block_size == 128) { | 
 |     // unrolling 16 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       __m256 vop64 = _mm256_setzero_ps(); | 
 |       __m256 vop72 = _mm256_setzero_ps(); | 
 |       __m256 vop80 = _mm256_setzero_ps(); | 
 |       __m256 vop88 = _mm256_setzero_ps(); | 
 |       __m256 vop96 = _mm256_setzero_ps(); | 
 |       __m256 vop104 = _mm256_setzero_ps(); | 
 |       __m256 vop112 = _mm256_setzero_ps(); | 
 |       __m256 vop120 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), | 
 |             vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), | 
 |             vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), | 
 |             vop16); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), | 
 |             vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), | 
 |             vop32); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0); | 
 |         vop40 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), | 
 |             vop40); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), | 
 |             vop48); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[48]) | 
 |         vop56 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), | 
 |             vop56); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |         vop64 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))), | 
 |             vop64); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0); | 
 |         vop72 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))), | 
 |             vop72); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[72]) | 
 |         vop80 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))), | 
 |             vop80); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[80]) | 
 |         vop88 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))), | 
 |             vop88); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[88]) | 
 |         vop96 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))), | 
 |             vop96); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0); | 
 |         vop104 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))), | 
 |             vop104); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[104]) | 
 |         vop112 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))), | 
 |             vop112); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[112]) | 
 |         vop120 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))), | 
 |             vop120); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[120]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |         _mm256_storeu_ps(&op[64], vop64); | 
 |         _mm256_storeu_ps(&op[72], vop72); | 
 |         _mm256_storeu_ps(&op[80], vop80); | 
 |         _mm256_storeu_ps(&op[88], vop88); | 
 |         _mm256_storeu_ps(&op[96], vop96); | 
 |         _mm256_storeu_ps(&op[104], vop104); | 
 |         _mm256_storeu_ps(&op[112], vop112); | 
 |         _mm256_storeu_ps(&op[120], vop120); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 64) { | 
 |     // unrolling 8 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), | 
 |             vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), | 
 |             vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), | 
 |             vop16); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), | 
 |             vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), | 
 |             vop32); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0); | 
 |         vop40 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), | 
 |             vop40); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), | 
 |             vop48); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[48]) | 
 |         vop56 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), | 
 |             vop56); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 32) { | 
 |     // unrolling 4 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), | 
 |             vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), | 
 |             vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), | 
 |             vop16); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), | 
 |             vop24); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 16) { | 
 |     // unrolling 2 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), | 
 |             vop0); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtph_ps( | 
 |                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), | 
 |             vop8); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else { | 
    // generic path: handles any block_size with 8-wide fp16->fp32 vector
    // iterations plus a scalar tail that widens one element at a time via the
    // vtmp1 scratch buffer below
 |     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers) | 
 |     alignas(64) at::Half vtmp1[8] = {0}; | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       int64_t j = 0; | 
 |       for (; j + 8 <= block_size; j += 8) { | 
 |         _mm256_storeu_ps(op + j, _mm256_setzero_ps()); | 
 |       } | 
 |       for (; j < block_size; j++) { | 
 |         op[j] = 0.0f; | 
 |       } | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const at::Half* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], | 
 |               _mm256_fmadd_ps( | 
 |                   vwgt, | 
 |                   _mm256_cvtph_ps(_mm_loadu_si128( | 
 |                       reinterpret_cast<const __m128i*>(&ip[j]))), | 
 |                   _mm256_loadu_ps(&op[j]))); | 
 |           _mm_prefetch( | 
 |               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           vtmp1[0] = ip[j]; | 
 |           __m256 vtmp2 = | 
 |               _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1))); | 
          op[j] = std::fma(wgt, _mm256_cvtss_f32(vtmp2), op[j]);
 |         } | 
 |       } | 
 |       if (normalize_by_lengths && lengths[rangeIndex]) { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         float len_inv = 1.0f / lengths[rangeIndex]; | 
 |         __m256 vlen_inv = _mm256_set1_ps(len_inv); | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = len_inv * op[j]; | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |   return dataInd == index_size; | 
 | } | 
 | bool EmbeddingLookup_int64_t_half_float_false__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const at::Half* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int64_t_half_float__avx2_fma<false>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 | bool EmbeddingLookup_int64_t_half_float_true__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const at::Half* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int64_t_half_float__avx2_fma<true>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 |  | 
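// Same lookup over an 8-bit quantized (uint8_t) table with int32_t indices.
// scale_bias is required here and holds one {scale, bias} float pair per row;
// row idx dequantizes as scale_bias[2 * idx] * q + scale_bias[2 * idx + 1].
// The per-row weight is folded into both terms (wgt and bio below), so each
// 8-element step is a uint8->fp32 widen (_mm256_cvtepu8_epi32 followed by
// _mm256_cvtepi32_ps), a bias-broadcast add, and one FMA.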
 | template <bool IS_WEIGHT_POSITIONAL> | 
 | static bool EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const uint8_t* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   const int prefdist_T0 = 16; | 
 |   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |   const int fused_block_size = block_size + 0; | 
 |   int dataInd = 0; | 
 |   if (block_size == 128) { | 
 |     // unrolling 16 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       __m256 vop64 = _mm256_setzero_ps(); | 
 |       __m256 vop72 = _mm256_setzero_ps(); | 
 |       __m256 vop80 = _mm256_setzero_ps(); | 
 |       __m256 vop88 = _mm256_setzero_ps(); | 
 |       __m256 vop96 = _mm256_setzero_ps(); | 
 |       __m256 vop104 = _mm256_setzero_ps(); | 
 |       __m256 vop112 = _mm256_setzero_ps(); | 
 |       __m256 vop120 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), | 
 |             _mm256_add_ps(vop0, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), | 
 |             _mm256_add_ps(vop8, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), | 
 |             _mm256_add_ps(vop16, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), | 
 |             _mm256_add_ps(vop24, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))), | 
 |             _mm256_add_ps(vop32, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[32]) | 
 |         vop40 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))), | 
 |             _mm256_add_ps(vop40, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))), | 
 |             _mm256_add_ps(vop48, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[48]) | 
 |         vop56 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))), | 
 |             _mm256_add_ps(vop56, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |         vop64 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))), | 
 |             _mm256_add_ps(vop64, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0); | 
 |         vop72 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))), | 
 |             _mm256_add_ps(vop72, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[72]) | 
 |         vop80 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))), | 
 |             _mm256_add_ps(vop80, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[80]) | 
 |         vop88 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))), | 
 |             _mm256_add_ps(vop88, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[88]) | 
 |         vop96 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))), | 
 |             _mm256_add_ps(vop96, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[96]) | 
 |         vop104 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))), | 
 |             _mm256_add_ps(vop104, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[104]) | 
 |         vop112 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))), | 
 |             _mm256_add_ps(vop112, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[112]) | 
 |         vop120 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))), | 
 |             _mm256_add_ps(vop120, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[120]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |         _mm256_storeu_ps(&op[64], vop64); | 
 |         _mm256_storeu_ps(&op[72], vop72); | 
 |         _mm256_storeu_ps(&op[80], vop80); | 
 |         _mm256_storeu_ps(&op[88], vop88); | 
 |         _mm256_storeu_ps(&op[96], vop96); | 
 |         _mm256_storeu_ps(&op[104], vop104); | 
 |         _mm256_storeu_ps(&op[112], vop112); | 
 |         _mm256_storeu_ps(&op[120], vop120); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 64) { | 
 |     // unrolling 8 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), | 
 |             _mm256_add_ps(vop0, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), | 
 |             _mm256_add_ps(vop8, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), | 
 |             _mm256_add_ps(vop16, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), | 
 |             _mm256_add_ps(vop24, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))), | 
 |             _mm256_add_ps(vop32, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[32]) | 
 |         vop40 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))), | 
 |             _mm256_add_ps(vop40, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))), | 
 |             _mm256_add_ps(vop48, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[48]) | 
 |         vop56 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))), | 
 |             _mm256_add_ps(vop56, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 32) { | 
 |     // unrolling 4 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), | 
 |             _mm256_add_ps(vop0, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), | 
 |             _mm256_add_ps(vop8, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), | 
 |             _mm256_add_ps(vop16, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), | 
 |             _mm256_add_ps(vop24, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 16) { | 
 |     // unrolling 2 times | 
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), | 
 |             _mm256_add_ps(vop0, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), | 
 |             _mm256_add_ps(vop8, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else { | 
 |     // generic code | 
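    // Fallback for block sizes without a dedicated unrolled branch: zero the
    // output row, accumulate 8 floats at a time with the same fused
    // dequantize-and-FMA pattern as above, and finish the block_size % 8
    // remainder with scalar std::fma.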
 |     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       int64_t j = 0; | 
 |       for (; j + 8 <= block_size; j += 8) { | 
 |         _mm256_storeu_ps(op + j, _mm256_setzero_ps()); | 
 |       } | 
 |       for (; j < block_size; j++) { | 
 |         op[j] = 0.0f; | 
 |       } | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], | 
 |               _mm256_fmadd_ps( | 
 |                   vwgt, | 
 |                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( | 
 |                       reinterpret_cast<const __m128i*>(&ip[j])))), | 
 |                   _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); | 
 |           _mm_prefetch( | 
 |               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = std::fma(wgt, (float)ip[j], bio + op[j]); | 
 |         } | 
 |       } | 
 |       if (normalize_by_lengths && lengths[rangeIndex]) { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         float len_inv = 1.0f / lengths[rangeIndex]; | 
 |         __m256 vlen_inv = _mm256_set1_ps(len_inv); | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = len_inv * op[j]; | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |   return dataInd == index_size; | 
 | } | 
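// ---------------------------------------------------------------------------
// Editorial note: the scalar sketch below is NOT produced by
// hp_emblookup_codegen.py. It is an illustrative reference for what the
// fused-dequantization kernel above computes (non-positional weights only);
// the function name is hypothetical and the function is never called.
// ---------------------------------------------------------------------------
[[maybe_unused]] static bool EmbeddingLookupUint8ScalarSketch(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int* indices,
    const int* lengths,
    const float* weights, // optional per-sample weights, may be nullptr
    const float* scale_bias, // two floats per row: {scale, bias}
    bool normalize_by_lengths,
    float* out) {
  int64_t dataInd = 0;
  for (int64_t r = 0; r < output_size; ++r) {
    float* op = &out[r * block_size];
    for (int64_t j = 0; j < block_size; ++j) {
      op[j] = 0.0f;
    }
    if (dataInd + lengths[r] > index_size) {
      return false;
    }
    for (int64_t start = dataInd; dataInd < start + lengths[r]; ++dataInd) {
      const int idx = indices[dataInd];
      if (idx < 0 || idx >= data_size) {
        return false;
      }
      const float wgt = weights ? weights[dataInd] : 1.f;
      const float scale = scale_bias[2 * idx];
      const float bias = scale_bias[2 * idx + 1];
      const uint8_t* ip = &input[idx * block_size];
      for (int64_t j = 0; j < block_size; ++j) {
        // Dequantize, weight, and accumulate in one step.
        op[j] += wgt * (scale * static_cast<float>(ip[j]) + bias);
      }
    }
    if (normalize_by_lengths && lengths[r] != 0) {
      const float len_inv = 1.0f / static_cast<float>(lengths[r]);
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv;
      }
    }
  }
  return dataInd == index_size;
}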
 | bool EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const uint8_t* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 | bool EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const uint8_t* input, | 
 |     const int* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
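// The _false/_true wrapper suffixes above fix the IS_WEIGHT_POSITIONAL
// template argument: with positional weights, the i-th entry within each bag
// is applied (weights[dataInd - start]); otherwise weights are indexed per
// sample across the whole indices array (weights[dataInd]).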
 |  | 
 | template <bool IS_WEIGHT_POSITIONAL> | 
 | static bool EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const uint8_t* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   const int64_t prefdist_T0 = 16; | 
 |   const int64_t fused_block_size = block_size + 0; | 
 |   int64_t dataInd = 0; | 
 |   if (block_size == 128) { | 
 |     // unrolling 16 times | 
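    // 128 output floats fit in sixteen 8-lane ymm accumulators
    // (vop0 ... vop120), so each pooled row stays entirely in registers
    // until the final store or normalization below.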
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       __m256 vop64 = _mm256_setzero_ps(); | 
 |       __m256 vop72 = _mm256_setzero_ps(); | 
 |       __m256 vop80 = _mm256_setzero_ps(); | 
 |       __m256 vop88 = _mm256_setzero_ps(); | 
 |       __m256 vop96 = _mm256_setzero_ps(); | 
 |       __m256 vop104 = _mm256_setzero_ps(); | 
 |       __m256 vop112 = _mm256_setzero_ps(); | 
 |       __m256 vop120 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), | 
 |             _mm256_add_ps(vop0, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), | 
 |             _mm256_add_ps(vop8, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), | 
 |             _mm256_add_ps(vop16, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), | 
 |             _mm256_add_ps(vop24, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))), | 
 |             _mm256_add_ps(vop32, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[32]) | 
 |         vop40 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))), | 
 |             _mm256_add_ps(vop40, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))), | 
 |             _mm256_add_ps(vop48, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[48]) | 
 |         vop56 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))), | 
 |             _mm256_add_ps(vop56, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |         vop64 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))), | 
 |             _mm256_add_ps(vop64, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0); | 
 |         vop72 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))), | 
 |             _mm256_add_ps(vop72, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[72]) | 
 |         vop80 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))), | 
 |             _mm256_add_ps(vop80, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[80]) | 
 |         vop88 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))), | 
 |             _mm256_add_ps(vop88, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[88]) | 
 |         vop96 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))), | 
 |             _mm256_add_ps(vop96, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[96]) | 
 |         vop104 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))), | 
 |             _mm256_add_ps(vop104, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[104]) | 
 |         vop112 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))), | 
 |             _mm256_add_ps(vop112, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[112]) | 
 |         vop120 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))), | 
 |             _mm256_add_ps(vop120, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[120]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |         _mm256_storeu_ps(&op[64], vop64); | 
 |         _mm256_storeu_ps(&op[72], vop72); | 
 |         _mm256_storeu_ps(&op[80], vop80); | 
 |         _mm256_storeu_ps(&op[88], vop88); | 
 |         _mm256_storeu_ps(&op[96], vop96); | 
 |         _mm256_storeu_ps(&op[104], vop104); | 
 |         _mm256_storeu_ps(&op[112], vop112); | 
 |         _mm256_storeu_ps(&op[120], vop120); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 64) { | 
 |     // unrolling 8 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       __m256 vop32 = _mm256_setzero_ps(); | 
 |       __m256 vop40 = _mm256_setzero_ps(); | 
 |       __m256 vop48 = _mm256_setzero_ps(); | 
 |       __m256 vop56 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), | 
 |             _mm256_add_ps(vop0, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), | 
 |             _mm256_add_ps(vop8, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), | 
 |             _mm256_add_ps(vop16, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), | 
 |             _mm256_add_ps(vop24, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |         vop32 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))), | 
 |             _mm256_add_ps(vop32, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[32]) | 
 |         vop40 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))), | 
 |             _mm256_add_ps(vop40, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[40]) | 
 |         vop48 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))), | 
 |             _mm256_add_ps(vop48, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[48]) | 
 |         vop56 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))), | 
 |             _mm256_add_ps(vop56, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[56]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |         _mm256_storeu_ps(&op[32], vop32); | 
 |         _mm256_storeu_ps(&op[40], vop40); | 
 |         _mm256_storeu_ps(&op[48], vop48); | 
 |         _mm256_storeu_ps(&op[56], vop56); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 32) { | 
 |     // unrolling 4 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       __m256 vop16 = _mm256_setzero_ps(); | 
 |       __m256 vop24 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), | 
 |             _mm256_add_ps(vop0, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), | 
 |             _mm256_add_ps(vop8, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |         vop16 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), | 
 |             _mm256_add_ps(vop16, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[16]) | 
 |         vop24 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), | 
 |             _mm256_add_ps(vop24, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[24]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |         _mm256_storeu_ps(&op[16], vop16); | 
 |         _mm256_storeu_ps(&op[24], vop24); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else if (block_size == 16) { | 
 |     // unrolling 2 times | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       __m256 vop0 = _mm256_setzero_ps(); | 
 |       __m256 vop8 = _mm256_setzero_ps(); | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         vop0 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), | 
 |             _mm256_add_ps(vop0, vbio)); | 
 |         _mm_prefetch( | 
 |             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0); | 
 |         vop8 = _mm256_fmadd_ps( | 
 |             vwgt, | 
 |             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( | 
 |                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), | 
 |             _mm256_add_ps(vop8, vbio)); | 
 |         // skip unnecessary prefetch of (&ip_next_T0[8]) | 
 |       } | 
 |       if (!normalize_by_lengths || lengths[rangeIndex] == 0) { | 
 |         _mm256_storeu_ps(&op[0], vop0); | 
 |         _mm256_storeu_ps(&op[8], vop8); | 
 |       } else { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); | 
 |         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); | 
 |         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); | 
 |       } | 
 |     } | 
 |   } else { | 
 |     // generic code | 
 |     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { | 
 |       float* op = &out[rangeIndex * block_size]; | 
 |       int64_t j = 0; | 
 |       for (; j + 8 <= block_size; j += 8) { | 
 |         _mm256_storeu_ps(op + j, _mm256_setzero_ps()); | 
 |       } | 
 |       for (; j < block_size; j++) { | 
 |         op[j] = 0.0f; | 
 |       } | 
 |       if (dataInd + lengths[rangeIndex] > index_size) { | 
 |         return false; | 
 |       } | 
 |       for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; | 
 |            ++dataInd) { | 
 |         const int64_t idx = indices[dataInd]; | 
 |         if (idx < 0 || idx >= data_size) { | 
 |           return false; | 
 |         } | 
 |         float wgt = 1.f; | 
 |         // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |         float bio; | 
 |         if (weights) { | 
 |           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; | 
 |         } | 
 |         bio = wgt * scale_bias[2 * idx + 1]; | 
 |         wgt = wgt * scale_bias[2 * idx]; | 
 |         __m256 vbio = _mm256_set1_ps(bio); | 
 |         __m256 vwgt = _mm256_set1_ps(wgt); | 
 |         const uint8_t* ip = &input[idx * fused_block_size]; | 
 |         const int64_t next_T0 = (dataInd < index_size - prefdist_T0) | 
 |             ? (dataInd + prefdist_T0) | 
 |             : dataInd; | 
 |         const int64_t idx_pref_T0 = indices[next_T0]; | 
 |         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { | 
 |           return false; | 
 |         } | 
 |         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], | 
 |               _mm256_fmadd_ps( | 
 |                   vwgt, | 
 |                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( | 
 |                       reinterpret_cast<const __m128i*>(&ip[j])))), | 
 |                   _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); | 
 |           _mm_prefetch( | 
 |               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = std::fma(wgt, (float)ip[j], bio + op[j]); | 
 |         } | 
 |       } | 
 |       if (normalize_by_lengths && lengths[rangeIndex]) { | 
 |         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | 
 |         float len_inv = 1.0f / lengths[rangeIndex]; | 
 |         __m256 vlen_inv = _mm256_set1_ps(len_inv); | 
 |         j = 0; | 
 |         for (; j + 8 <= block_size; j += 8) { | 
 |           _mm256_storeu_ps( | 
 |               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); | 
 |         } | 
 |         for (; j < block_size; j++) { | 
 |           op[j] = len_inv * op[j]; | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |   return dataInd == index_size; | 
 | } | 
 | bool EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const uint8_t* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
 | bool EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma( | 
 |     const int64_t block_size, | 
 |     const int64_t output_size, | 
 |     const int64_t index_size, | 
 |     const int64_t data_size, | 
 |     const uint8_t* input, | 
 |     const int64_t* indices, | 
 |     const int* lengths, | 
 |     const float* weights, | 
 |     const float* scale_bias, | 
 |     bool normalize_by_lengths, | 
 |     float* out) { | 
 |   return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>( | 
 |       block_size, | 
 |       output_size, | 
 |       index_size, | 
 |       data_size, | 
 |       input, | 
 |       indices, | 
 |       lengths, | 
 |       weights, | 
 |       scale_bias, | 
 |       normalize_by_lengths, | 
 |       out); | 
 | } | 
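// ---------------------------------------------------------------------------
// Editorial note: minimal usage sketch, not part of the generated code. It
// drives the exported int64/uint8 wrapper above with a tiny hypothetical
// 4-row quantized table (2 columns, per-row {scale, bias}) and two output
// rows pooling 2 and 1 indices respectively; `out` must hold 2 * 2 floats.
// With these inputs, out becomes {2.0f, 3.0f, 3.0f, 3.5f}.
// ---------------------------------------------------------------------------
[[maybe_unused]] static bool ExampleUint8EmbeddingLookupSketch(float* out) {
  const uint8_t input[4 * 2] = {0, 1, 2, 3, 4, 5, 6, 7}; // quantized rows
  const float scale_bias[4 * 2] = {
      0.5f, 0.f, 0.5f, 0.f, 0.5f, 0.f, 0.5f, 0.f}; // {scale, bias} per row
  const int64_t indices[3] = {0, 2, 3}; // rows gathered, grouped by lengths
  const int lengths[2] = {2, 1}; // bag 0 pools rows {0, 2}; bag 1 pools {3}
  return EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
      /*block_size=*/2, // dispatches to the generic path above
      /*output_size=*/2,
      /*index_size=*/3,
      /*data_size=*/4,
      input,
      indices,
      lengths,
      /*weights=*/nullptr,
      scale_bias,
      /*normalize_by_lengths=*/false,
      out);
}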
 |  | 
 | } // namespace caffe2 |