from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
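
# Generates AVX2 + FMA specializations of caffe2's EmbeddingLookup kernels.
# For each (IndexType, InType, OutType) combination listed in `options` it
# emits fully unrolled kernels for block sizes 32, 64 and 128 plus a generic
# fallback, and writes everything to a single C++ source file.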


def unroll(uf, IndexType, InType, OutType, use_weights, isa):
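    """Emit a kernel body fully unrolled for block_size == 8 * uf floats."""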

    def compute(regid, InType, use_weights, isa):
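        # One FMA accumulate into the vop<regid> register (converting fp16
        # inputs on the fly) followed by a prefetch of the next row.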
        code = []

        if InType == "float":
            code.append(
                "vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"
                % (regid, regid, regid))
        else:
            code.append(
                "vop%d = _mm256_fmadd_ps(vwgt, "
                "_mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))), "
                "vop%d);"
                % (regid, regid, regid))

        code.append("_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid))

        return code
    code = []
    code.append("// unrolling " + str(uf) + " times")
    code.append(IndexType + " dataInd = 0;")
    code.append("for (" + IndexType +
                " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
    code.append(OutType + " *op = &out[rangeIndex * block_size];")
    for i in range(0, uf):
        j = 8 * i
        code.append("__m256 vop" + str(j) + " = _mm256_setzero_ps();")

    # inner loop
    code.append("for (" + IndexType +
                " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
    code.append("const " + IndexType + " idx = indices[dataInd];")
    code.append(OutType + " wgt = 1.f;")
    code.append("if (weights) {")
    code.append("wgt = weights[dataInd];")
    code.append("}")
    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
    code.append("const " + InType + " *ip = &input[idx * block_size];")
    code.append("const " + IndexType +
                " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;")
    code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "assert(idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && idx_pref_T0 < data_size);")
    code.append("const " + InType +
                " *ip_next_T0 = &input[idx_pref_T0 * block_size];")

    for i in range(0, uf):
        j = 8 * i
        code.extend(compute(j, InType, use_weights, isa))
    code.append("}")

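    # store the accumulated rows, scaled by 1 / lengths[rangeIndex] when normalizing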
| code.append("if (normalize_by_lengths == false) {") |
| for i in range(0, uf): |
| j = 8 * i |
| code.append( |
| "_mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");") |
| code.append("} else if (lengths[rangeIndex]) {") |
| # inv of length |
| code.append( |
| "__m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);") |
| for i in range(0, uf): |
| j = 8 * i |
| code.append( |
| "_mm256_storeu_ps(&op[" + str(j) + "], _mm256_mul_ps(" + "vop" + str(j) + ", vlen_inv));") |
| code.append("}") |
| |
| code.append("}") |
| return code |


def generic(IndexType, InType, OutType, use_weights, isa):
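    """Emit a kernel body that handles an arbitrary block_size."""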

    def compute(InType, use_weights, isa):
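        # Accumulate 8 consecutive floats of op with one FMA (converting
        # fp16 inputs on the fly), then prefetch the next row.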
        code = []
        if InType == "float":
            code.append(
                "_mm256_storeu_ps(&op[j], _mm256_fmadd_ps(vwgt, "
                "_mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));")
        else:
            code.append(
                "_mm256_storeu_ps(&op[j], _mm256_fmadd_ps(vwgt, "
                "_mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(&ip[j]))), "
                "_mm256_loadu_ps(&op[j])));")

        code.append("_mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);")

        return code
    code = []
    code.append(IndexType + " dataInd = 0;")
    code.append("for (" + IndexType +
                " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
    code.append(OutType + " *op = &out[rangeIndex * block_size];")

    # initialize to 0
    code.append("TIndex j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.append("_mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("}")
    code.append("for(; j < block_size; j++) {")
    code.append("op[j] = 0.0f;")
    code.append("}")

    # inner loop
    code.append("for (" + IndexType +
                " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
    code.append("const " + IndexType + " idx = indices[dataInd];")
    code.append(OutType + " wgt = 1.f;")
    code.append("if (weights) {")
    code.append("wgt = weights[dataInd];")
    code.append("}")
    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
    code.append("const " + InType + " *ip = &input[idx * block_size];")
    code.append("const " + IndexType +
                " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;")
    code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "assert(idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && idx_pref_T0 < data_size);")
    code.append("const " + InType +
                " *ip_next_T0 = &input[idx_pref_T0 * block_size];")

    # compute and store main loop
    code.append("j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.extend(compute(InType, use_weights, isa))
    code.append("}")
    # leftover
    if InType == "float16":
        code.append("float16 vtmp1[8] __attribute__((aligned(64)));")
    code.append("for(; j < block_size; j++) {")
    if InType == "float":
        code.append("op[j] += wgt * ip[j];")
    else:
        code.append("vtmp1[0] = ip[j];")
        code.append("__m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));")
        code.append("op[j] += wgt * ((float*)(&vtmp2))[0];")
    code.append("}")

    code.append("}")

    code.append("if (normalize_by_lengths && lengths[rangeIndex]) {")
    code.append("float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("__m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.append(
        "_mm256_storeu_ps(&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));")
    code.append("}")
    code.append("for(; j < block_size; j++) {")
    code.append("op[j] = len_inv * op[j];")
    code.append("}")

    code.append("}")

    code.append("}")
    return code


# start main code
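# The output file defaults to embedding_lookup_avx2.cc and can be overridden
# with -f, e.g. (the script filename below is illustrative):
#   python hp_emblookup_codegen.py -f embedding_lookup_avx2.cc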
parser = argparse.ArgumentParser()
parser.add_argument('-f', nargs=1, help="file name")
opts = parser.parse_args()
filename = "embedding_lookup_avx2.cc"
if opts.f:
    filename = opts.f[0]
fout = open(filename, 'w')
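
# (IndexType, InType, OutType) combinations to generate specializations for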
options = [["int32_t", "float", "float"],
           ["int64_t", "float", "float"],
           ["int32_t", "float16", "float"],
           ["int64_t", "float16", "float"]]

code = []
# includes
code.append("#include \"caffe2/core/types.h\"")
code.append("#include \"caffe2/core/common.h\"")
code.append("#include <immintrin.h>")
code.append("\n")

code.append("namespace caffe2 {\n")
for o in options:
    [IndexType, InType, OutType] = o

    fn = "void EmbeddingLookup_" + IndexType + \
        "_" + InType + "_" + OutType + "__avx2_fma"
    code.append(fn + "(")
    code.append("const TIndex block_size,")
    code.append("const TIndex output_size,")
    code.append("const TIndex index_size,")
    code.append("const TIndex data_size,")
    code.append("const " + InType + "* input,")
    code.append("const " + IndexType + "* indices,")
    code.append("const int* lengths,")
    code.append("const float* weights,")
    code.append("bool normalize_by_lengths,")
    code.append(OutType + "* out)")

    code.append("{")
    code.append("const " + IndexType + " prefdist_T0 = 16;")
    # code.append("printf(\"calling " + fn + "\\n\");")

    code.append("if (block_size == 128) {")
    code.extend(unroll(16, IndexType, InType, OutType, True, "AVX2"))
    code.append("} else if (block_size == 64) {")
    code.extend(unroll(8, IndexType, InType, OutType, True, "AVX2"))
    code.append("} else if (block_size == 32) {")
    code.extend(unroll(4, IndexType, InType, OutType, True, "AVX2"))
    code.append("} else {")
    code.append("// generic code")
    code.extend(generic(IndexType, InType, OutType, True, "AVX2"))
    code.append("}")

    code.append("}")

    code.append("\n")

code.append("} // namespace caffe2")

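# For reference, the first entry in `options` produces a function of the form
#   void EmbeddingLookup_int32_t_float_float__avx2_fma(
#       const TIndex block_size, const TIndex output_size,
#       const TIndex index_size, const TIndex data_size,
#       const float* input, const int32_t* indices, const int* lengths,
#       const float* weights, bool normalize_by_lengths, float* out);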
for c in code:
    # print(c, file=fout)
    fout.write(c + "\n")
fout.close()

print("Created " + filename)