#include <cassert>
#include <immintrin.h>

#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
| |
| namespace caffe2 { |
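
// Embedding lookup kernels, AVX2 + FMA specializations.
//
// Every function in this file computes, for each output row r in
// [0, output_size):
//
//   out[r * block_size + k] =
//       sum over the lengths[r] positions i of segment r of
//       (weights ? weights[i] : 1.0f) * input[indices[i] * block_size + k]
//
// Segments are consecutive runs of `indices` (and `weights`), so the entries
// of `lengths` are expected to sum to index_size. If `normalize_by_lengths`
// is set and a segment is non-empty, the sum is divided by lengths[r]
// (mean pooling). Each iteration also prefetches the row that will be needed
// prefdist_T0 indices ahead (_MM_HINT_T0) to hide the latency of the
// gathered reads. block_size values of 128, 64, and 32 get fully unrolled
// paths that keep the whole output row in YMM accumulators; any other
// block_size takes a generic 8-wide loop with a scalar tail.
//
// Illustrative call (hypothetical names and shapes, not part of this file):
//
//   // table:   [data_size x block_size]   embedding rows (float)
//   // indices: [index_size]               rows to gather (int32_t)
//   // lengths: [output_size]              segment sizes, sum == index_size
//   // out:     [output_size x block_size] pooled output (float)
//   EmbeddingLookup_int32_t_float_float__avx2_fma(
//       block_size, output_size, index_size, data_size,
//       table, indices, lengths, /*weights=*/nullptr,
//       /*normalize_by_lengths=*/true, out);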
| |
| void EmbeddingLookup_int32_t_float_float__avx2_fma( |
| const TIndex block_size, |
| const TIndex output_size, |
| const TIndex index_size, |
| const TIndex data_size, |
| const float* input, |
| const int32_t* indices, |
| const int* lengths, |
| const float* weights, |
| bool normalize_by_lengths, |
| float* out) { |
| const int32_t prefdist_T0 = 16; |
| if (block_size == 128) { |
| // unrolling 16 times |
| int32_t dataInd = 0; |
| for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| __m256 vop32 = _mm256_setzero_ps(); |
| __m256 vop40 = _mm256_setzero_ps(); |
| __m256 vop48 = _mm256_setzero_ps(); |
| __m256 vop56 = _mm256_setzero_ps(); |
| __m256 vop64 = _mm256_setzero_ps(); |
| __m256 vop72 = _mm256_setzero_ps(); |
| __m256 vop80 = _mm256_setzero_ps(); |
| __m256 vop88 = _mm256_setzero_ps(); |
| __m256 vop96 = _mm256_setzero_ps(); |
| __m256 vop104 = _mm256_setzero_ps(); |
| __m256 vop112 = _mm256_setzero_ps(); |
| __m256 vop120 = _mm256_setzero_ps(); |
| for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int32_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float* ip = &input[idx * block_size]; |
| const int32_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int32_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); |
| _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); |
| vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); |
| _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); |
| vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); |
| _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); |
| vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); |
| _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); |
| vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); |
| _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); |
| vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); |
| _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); |
| vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); |
| _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); |
| vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); |
| _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); |
| vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); |
| _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); |
| vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); |
| _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); |
| vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); |
| _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); |
| vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); |
| _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| _mm256_storeu_ps(&op[32], vop32); |
| _mm256_storeu_ps(&op[40], vop40); |
| _mm256_storeu_ps(&op[48], vop48); |
| _mm256_storeu_ps(&op[56], vop56); |
| _mm256_storeu_ps(&op[64], vop64); |
| _mm256_storeu_ps(&op[72], vop72); |
| _mm256_storeu_ps(&op[80], vop80); |
| _mm256_storeu_ps(&op[88], vop88); |
| _mm256_storeu_ps(&op[96], vop96); |
| _mm256_storeu_ps(&op[104], vop104); |
| _mm256_storeu_ps(&op[112], vop112); |
| _mm256_storeu_ps(&op[120], vop120); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); |
| _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); |
| _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); |
| _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); |
| _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); |
| _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); |
| _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); |
| _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); |
| _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); |
| _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); |
| _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); |
| _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); |
| } |
| } |
| } else if (block_size == 64) { |
| // unrolling 8 times |
| int32_t dataInd = 0; |
| for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| __m256 vop32 = _mm256_setzero_ps(); |
| __m256 vop40 = _mm256_setzero_ps(); |
| __m256 vop48 = _mm256_setzero_ps(); |
| __m256 vop56 = _mm256_setzero_ps(); |
| for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int32_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float* ip = &input[idx * block_size]; |
| const int32_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int32_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); |
| _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); |
| vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); |
| _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); |
| vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); |
| _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); |
| vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); |
| _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| _mm256_storeu_ps(&op[32], vop32); |
| _mm256_storeu_ps(&op[40], vop40); |
| _mm256_storeu_ps(&op[48], vop48); |
| _mm256_storeu_ps(&op[56], vop56); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); |
| _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); |
| _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); |
| _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); |
| } |
| } |
| } else if (block_size == 32) { |
| // unrolling 4 times |
| int32_t dataInd = 0; |
| for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int32_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float* ip = &input[idx * block_size]; |
| const int32_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int32_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| } |
| } |
| } else { |
    // generic code: arbitrary block_size; 8-wide vector loop plus scalar tail
| int32_t dataInd = 0; |
| for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| TIndex j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps(op + j, _mm256_setzero_ps()); |
| } |
| for (; j < block_size; j++) { |
| op[j] = 0.0f; |
| } |
| for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int32_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float* ip = &input[idx * block_size]; |
| const int32_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int32_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps( |
| &op[j], |
| _mm256_fmadd_ps( |
| vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); |
| _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); |
| } |
| for (; j < block_size; j++) { |
| op[j] += wgt * ip[j]; |
| } |
| } |
| if (normalize_by_lengths && lengths[rangeIndex]) { |
| float len_inv = 1.0f / lengths[rangeIndex]; |
| __m256 vlen_inv = _mm256_set1_ps(len_inv); |
| j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps( |
| &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); |
| } |
| for (; j < block_size; j++) { |
| op[j] = len_inv * op[j]; |
| } |
| } |
| } |
| } |
| } |
| |
| void EmbeddingLookup_int64_t_float_float__avx2_fma( |
| const TIndex block_size, |
| const TIndex output_size, |
| const TIndex index_size, |
| const TIndex data_size, |
| const float* input, |
| const int64_t* indices, |
| const int* lengths, |
| const float* weights, |
| bool normalize_by_lengths, |
| float* out) { |
| const int64_t prefdist_T0 = 16; |
| if (block_size == 128) { |
| // unrolling 16 times |
| int64_t dataInd = 0; |
| for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| __m256 vop32 = _mm256_setzero_ps(); |
| __m256 vop40 = _mm256_setzero_ps(); |
| __m256 vop48 = _mm256_setzero_ps(); |
| __m256 vop56 = _mm256_setzero_ps(); |
| __m256 vop64 = _mm256_setzero_ps(); |
| __m256 vop72 = _mm256_setzero_ps(); |
| __m256 vop80 = _mm256_setzero_ps(); |
| __m256 vop88 = _mm256_setzero_ps(); |
| __m256 vop96 = _mm256_setzero_ps(); |
| __m256 vop104 = _mm256_setzero_ps(); |
| __m256 vop112 = _mm256_setzero_ps(); |
| __m256 vop120 = _mm256_setzero_ps(); |
| for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int64_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float* ip = &input[idx * block_size]; |
| const int64_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int64_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); |
| _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); |
| vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); |
| _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); |
| vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); |
| _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); |
| vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); |
| _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); |
| vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); |
| _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); |
| vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); |
| _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); |
| vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); |
| _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); |
| vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); |
| _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); |
| vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); |
| _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); |
| vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); |
| _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); |
| vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); |
| _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); |
| vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); |
| _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| _mm256_storeu_ps(&op[32], vop32); |
| _mm256_storeu_ps(&op[40], vop40); |
| _mm256_storeu_ps(&op[48], vop48); |
| _mm256_storeu_ps(&op[56], vop56); |
| _mm256_storeu_ps(&op[64], vop64); |
| _mm256_storeu_ps(&op[72], vop72); |
| _mm256_storeu_ps(&op[80], vop80); |
| _mm256_storeu_ps(&op[88], vop88); |
| _mm256_storeu_ps(&op[96], vop96); |
| _mm256_storeu_ps(&op[104], vop104); |
| _mm256_storeu_ps(&op[112], vop112); |
| _mm256_storeu_ps(&op[120], vop120); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); |
| _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); |
| _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); |
| _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); |
| _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); |
| _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); |
| _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); |
| _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); |
| _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); |
| _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); |
| _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); |
| _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); |
| } |
| } |
| } else if (block_size == 64) { |
| // unrolling 8 times |
| int64_t dataInd = 0; |
| for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| __m256 vop32 = _mm256_setzero_ps(); |
| __m256 vop40 = _mm256_setzero_ps(); |
| __m256 vop48 = _mm256_setzero_ps(); |
| __m256 vop56 = _mm256_setzero_ps(); |
| for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int64_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float* ip = &input[idx * block_size]; |
| const int64_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int64_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); |
| _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); |
| vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); |
| _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); |
| vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); |
| _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); |
| vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); |
| _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| _mm256_storeu_ps(&op[32], vop32); |
| _mm256_storeu_ps(&op[40], vop40); |
| _mm256_storeu_ps(&op[48], vop48); |
| _mm256_storeu_ps(&op[56], vop56); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); |
| _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); |
| _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); |
| _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); |
| } |
| } |
| } else if (block_size == 32) { |
| // unrolling 4 times |
| int64_t dataInd = 0; |
| for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int64_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float* ip = &input[idx * block_size]; |
| const int64_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int64_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| } |
| } |
| } else { |
    // generic code: arbitrary block_size; 8-wide vector loop plus scalar tail
| int64_t dataInd = 0; |
| for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| TIndex j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps(op + j, _mm256_setzero_ps()); |
| } |
| for (; j < block_size; j++) { |
| op[j] = 0.0f; |
| } |
| for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int64_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float* ip = &input[idx * block_size]; |
| const int64_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int64_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps( |
| &op[j], |
| _mm256_fmadd_ps( |
| vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); |
| _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); |
| } |
| for (; j < block_size; j++) { |
| op[j] += wgt * ip[j]; |
| } |
| } |
| if (normalize_by_lengths && lengths[rangeIndex]) { |
| float len_inv = 1.0f / lengths[rangeIndex]; |
| __m256 vlen_inv = _mm256_set1_ps(len_inv); |
| j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps( |
| &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); |
| } |
| for (; j < block_size; j++) { |
| op[j] = len_inv * op[j]; |
| } |
| } |
| } |
| } |
| } |
| |
| void EmbeddingLookup_int32_t_float16_float__avx2_fma( |
| const TIndex block_size, |
| const TIndex output_size, |
| const TIndex index_size, |
| const TIndex data_size, |
| const float16* input, |
| const int32_t* indices, |
| const int* lengths, |
| const float* weights, |
| bool normalize_by_lengths, |
| float* out) { |
| const int32_t prefdist_T0 = 16; |
| if (block_size == 128) { |
| // unrolling 16 times |
| int32_t dataInd = 0; |
| for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| __m256 vop32 = _mm256_setzero_ps(); |
| __m256 vop40 = _mm256_setzero_ps(); |
| __m256 vop48 = _mm256_setzero_ps(); |
| __m256 vop56 = _mm256_setzero_ps(); |
| __m256 vop64 = _mm256_setzero_ps(); |
| __m256 vop72 = _mm256_setzero_ps(); |
| __m256 vop80 = _mm256_setzero_ps(); |
| __m256 vop88 = _mm256_setzero_ps(); |
| __m256 vop96 = _mm256_setzero_ps(); |
| __m256 vop104 = _mm256_setzero_ps(); |
| __m256 vop112 = _mm256_setzero_ps(); |
| __m256 vop120 = _mm256_setzero_ps(); |
| for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int32_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float16* ip = &input[idx * block_size]; |
| const int32_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int32_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), |
| vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), |
| vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), |
| vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), |
| vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| vop32 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), |
| vop32); |
| _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); |
| vop40 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), |
| vop40); |
| _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); |
| vop48 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), |
| vop48); |
| _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); |
| vop56 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), |
| vop56); |
| _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); |
| vop64 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))), |
| vop64); |
| _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); |
| vop72 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))), |
| vop72); |
| _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); |
| vop80 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))), |
| vop80); |
| _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); |
| vop88 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))), |
| vop88); |
| _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); |
| vop96 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))), |
| vop96); |
| _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); |
| vop104 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))), |
| vop104); |
| _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); |
| vop112 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))), |
| vop112); |
| _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); |
| vop120 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))), |
| vop120); |
| _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| _mm256_storeu_ps(&op[32], vop32); |
| _mm256_storeu_ps(&op[40], vop40); |
| _mm256_storeu_ps(&op[48], vop48); |
| _mm256_storeu_ps(&op[56], vop56); |
| _mm256_storeu_ps(&op[64], vop64); |
| _mm256_storeu_ps(&op[72], vop72); |
| _mm256_storeu_ps(&op[80], vop80); |
| _mm256_storeu_ps(&op[88], vop88); |
| _mm256_storeu_ps(&op[96], vop96); |
| _mm256_storeu_ps(&op[104], vop104); |
| _mm256_storeu_ps(&op[112], vop112); |
| _mm256_storeu_ps(&op[120], vop120); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); |
| _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); |
| _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); |
| _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); |
| _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); |
| _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); |
| _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); |
| _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); |
| _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); |
| _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); |
| _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); |
| _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); |
| } |
| } |
| } else if (block_size == 64) { |
| // unrolling 8 times |
| int32_t dataInd = 0; |
| for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| __m256 vop32 = _mm256_setzero_ps(); |
| __m256 vop40 = _mm256_setzero_ps(); |
| __m256 vop48 = _mm256_setzero_ps(); |
| __m256 vop56 = _mm256_setzero_ps(); |
| for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int32_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float16* ip = &input[idx * block_size]; |
| const int32_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int32_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), |
| vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), |
| vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), |
| vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), |
| vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| vop32 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), |
| vop32); |
| _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); |
| vop40 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), |
| vop40); |
| _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); |
| vop48 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), |
| vop48); |
| _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); |
| vop56 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), |
| vop56); |
| _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| _mm256_storeu_ps(&op[32], vop32); |
| _mm256_storeu_ps(&op[40], vop40); |
| _mm256_storeu_ps(&op[48], vop48); |
| _mm256_storeu_ps(&op[56], vop56); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); |
| _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); |
| _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); |
| _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); |
| } |
| } |
| } else if (block_size == 32) { |
| // unrolling 4 times |
| int32_t dataInd = 0; |
| for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int32_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float16* ip = &input[idx * block_size]; |
| const int32_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int32_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), |
| vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), |
| vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), |
| vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), |
| vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| } |
| } |
| } else { |
    // generic code: arbitrary block_size; 8-wide vector loop plus scalar tail
| int32_t dataInd = 0; |
| for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| TIndex j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps(op + j, _mm256_setzero_ps()); |
| } |
| for (; j < block_size; j++) { |
| op[j] = 0.0f; |
| } |
| for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int32_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float16* ip = &input[idx * block_size]; |
| const int32_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int32_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps( |
| &op[j], |
| _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps(_mm_loadu_si128( |
| reinterpret_cast<const __m128i*>(&ip[j]))), |
| _mm256_loadu_ps(&op[j]))); |
| _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); |
| } |
| float16 vtmp1[8] __attribute__((aligned(64))); |
| for (; j < block_size; j++) { |
| vtmp1[0] = ip[j]; |
| __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1)); |
| op[j] += wgt * ((float*)(&vtmp2))[0]; |
| } |
| } |
| if (normalize_by_lengths && lengths[rangeIndex]) { |
| float len_inv = 1.0f / lengths[rangeIndex]; |
| __m256 vlen_inv = _mm256_set1_ps(len_inv); |
| j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps( |
| &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); |
| } |
| for (; j < block_size; j++) { |
| op[j] = len_inv * op[j]; |
| } |
| } |
| } |
| } |
| } |
| |
| void EmbeddingLookup_int64_t_float16_float__avx2_fma( |
| const TIndex block_size, |
| const TIndex output_size, |
| const TIndex index_size, |
| const TIndex data_size, |
| const float16* input, |
| const int64_t* indices, |
| const int* lengths, |
| const float* weights, |
| bool normalize_by_lengths, |
| float* out) { |
| const int64_t prefdist_T0 = 16; |
| if (block_size == 128) { |
| // unrolling 16 times |
| int64_t dataInd = 0; |
| for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| __m256 vop32 = _mm256_setzero_ps(); |
| __m256 vop40 = _mm256_setzero_ps(); |
| __m256 vop48 = _mm256_setzero_ps(); |
| __m256 vop56 = _mm256_setzero_ps(); |
| __m256 vop64 = _mm256_setzero_ps(); |
| __m256 vop72 = _mm256_setzero_ps(); |
| __m256 vop80 = _mm256_setzero_ps(); |
| __m256 vop88 = _mm256_setzero_ps(); |
| __m256 vop96 = _mm256_setzero_ps(); |
| __m256 vop104 = _mm256_setzero_ps(); |
| __m256 vop112 = _mm256_setzero_ps(); |
| __m256 vop120 = _mm256_setzero_ps(); |
| for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int64_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float16* ip = &input[idx * block_size]; |
| const int64_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int64_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), |
| vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), |
| vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), |
| vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), |
| vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| vop32 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), |
| vop32); |
| _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); |
| vop40 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), |
| vop40); |
| _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); |
| vop48 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), |
| vop48); |
| _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); |
| vop56 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), |
| vop56); |
| _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); |
| vop64 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))), |
| vop64); |
| _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); |
| vop72 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))), |
| vop72); |
| _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); |
| vop80 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))), |
| vop80); |
| _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); |
| vop88 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))), |
| vop88); |
| _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); |
| vop96 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))), |
| vop96); |
| _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); |
| vop104 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))), |
| vop104); |
| _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); |
| vop112 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))), |
| vop112); |
| _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); |
| vop120 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))), |
| vop120); |
| _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| _mm256_storeu_ps(&op[32], vop32); |
| _mm256_storeu_ps(&op[40], vop40); |
| _mm256_storeu_ps(&op[48], vop48); |
| _mm256_storeu_ps(&op[56], vop56); |
| _mm256_storeu_ps(&op[64], vop64); |
| _mm256_storeu_ps(&op[72], vop72); |
| _mm256_storeu_ps(&op[80], vop80); |
| _mm256_storeu_ps(&op[88], vop88); |
| _mm256_storeu_ps(&op[96], vop96); |
| _mm256_storeu_ps(&op[104], vop104); |
| _mm256_storeu_ps(&op[112], vop112); |
| _mm256_storeu_ps(&op[120], vop120); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); |
| _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); |
| _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); |
| _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); |
| _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); |
| _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); |
| _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); |
| _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); |
| _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); |
| _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); |
| _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); |
| _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); |
| } |
| } |
| } else if (block_size == 64) { |
| // unrolling 8 times |
| int64_t dataInd = 0; |
| for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| __m256 vop32 = _mm256_setzero_ps(); |
| __m256 vop40 = _mm256_setzero_ps(); |
| __m256 vop48 = _mm256_setzero_ps(); |
| __m256 vop56 = _mm256_setzero_ps(); |
| for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int64_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float16* ip = &input[idx * block_size]; |
| const int64_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int64_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), |
| vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), |
| vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), |
| vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), |
| vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| vop32 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), |
| vop32); |
| _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); |
| vop40 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), |
| vop40); |
| _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); |
| vop48 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), |
| vop48); |
| _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); |
| vop56 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), |
| vop56); |
| _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| _mm256_storeu_ps(&op[32], vop32); |
| _mm256_storeu_ps(&op[40], vop40); |
| _mm256_storeu_ps(&op[48], vop48); |
| _mm256_storeu_ps(&op[56], vop56); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); |
| _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); |
| _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); |
| _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); |
| } |
| } |
| } else if (block_size == 32) { |
| // unrolling 4 times |
| int64_t dataInd = 0; |
| for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| __m256 vop0 = _mm256_setzero_ps(); |
| __m256 vop8 = _mm256_setzero_ps(); |
| __m256 vop16 = _mm256_setzero_ps(); |
| __m256 vop24 = _mm256_setzero_ps(); |
| for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int64_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float16* ip = &input[idx * block_size]; |
| const int64_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int64_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| vop0 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), |
| vop0); |
| _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); |
| vop8 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), |
| vop8); |
| _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); |
| vop16 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), |
| vop16); |
| _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); |
| vop24 = _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps( |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), |
| vop24); |
| _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); |
| } |
| if (normalize_by_lengths == false) { |
| _mm256_storeu_ps(&op[0], vop0); |
| _mm256_storeu_ps(&op[8], vop8); |
| _mm256_storeu_ps(&op[16], vop16); |
| _mm256_storeu_ps(&op[24], vop24); |
| } else if (lengths[rangeIndex]) { |
| __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); |
| _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); |
| _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); |
| _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); |
| _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); |
| } |
| } |
| } else { |
    // generic code: arbitrary block_size; 8-wide vector loop plus scalar tail
| int64_t dataInd = 0; |
| for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { |
| float* op = &out[rangeIndex * block_size]; |
| TIndex j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps(op + j, _mm256_setzero_ps()); |
| } |
| for (; j < block_size; j++) { |
| op[j] = 0.0f; |
| } |
| for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; |
| ++dataInd) { |
| const int64_t idx = indices[dataInd]; |
| float wgt = 1.f; |
| if (weights) { |
| wgt = weights[dataInd]; |
| } |
| __m256 vwgt = _mm256_set1_ps(wgt); |
| const float16* ip = &input[idx * block_size]; |
| const int64_t next_T0 = (dataInd < index_size - prefdist_T0) |
| ? (dataInd + prefdist_T0) |
| : dataInd; |
| const int64_t idx_pref_T0 = indices[next_T0]; |
| assert( |
| idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && |
| idx_pref_T0 < data_size); |
| const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; |
| j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps( |
| &op[j], |
| _mm256_fmadd_ps( |
| vwgt, |
| _mm256_cvtph_ps(_mm_loadu_si128( |
| reinterpret_cast<const __m128i*>(&ip[j]))), |
| _mm256_loadu_ps(&op[j]))); |
| _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); |
| } |
| float16 vtmp1[8] __attribute__((aligned(64))); |
| for (; j < block_size; j++) { |
| vtmp1[0] = ip[j]; |
| __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1)); |
| op[j] += wgt * ((float*)(&vtmp2))[0]; |
| } |
| } |
| if (normalize_by_lengths && lengths[rangeIndex]) { |
| float len_inv = 1.0f / lengths[rangeIndex]; |
| __m256 vlen_inv = _mm256_set1_ps(len_inv); |
| j = 0; |
| for (; j + 8 <= block_size; j += 8) { |
| _mm256_storeu_ps( |
| &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); |
| } |
| for (; j < block_size; j++) { |
| op[j] = len_inv * op[j]; |
| } |
| } |
| } |
| } |
| } |
| |
| } // namespace caffe2 |