# blob: 341a6488a373e43adcca2c32ed8ada2e3ae57ffe [file] [log] [blame]
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import sys
def unroll(uf, IndexType, InType, OutType, use_weights, isa):
    """Generate C++ lines for an AVX2 embedding-lookup kernel unrolled ``uf`` times.

    The emitted code keeps ``uf`` ``__m256`` accumulators named ``vop0``,
    ``vop8``, ..., ``vop<8*(uf-1)>``, i.e. it covers exactly ``8 * uf``
    output columns, so the caller should place it inside a
    ``block_size == 8 * uf`` branch. The generated code refers to the
    parameters of the enclosing generated function (``input``, ``indices``,
    ``lengths``, ``weights``, ``scale_bias``, ``out``, ``block_size``,
    ``output_size``, ``index_size``, ``data_size``,
    ``normalize_by_lengths``) and to its local ``prefdist_T0``.

    Args:
        uf: unroll factor (number of 8-float accumulators).
        IndexType: C++ type of the ``indices`` array, e.g. ``"int32_t"``.
        InType: C++ element type of ``input``: ``"float"``, ``"float16"``
            or ``"uint8_t"``.
        OutType: C++ type of ``out`` and of the scalar weight.
        use_weights: not read by the generator; kept for interface
            compatibility.
        isa: not read by the generator; kept for interface compatibility.

    Returns:
        list of str: C++ source lines, without trailing newlines.

    Raises:
        ValueError: if ``InType`` is not one of the supported types.
    """

    def sizeof(InType):
        """Return the size in bytes of one ``InType`` element."""
        sizes = {"float": 4, "float16": 2, "uint8_t": 1}
        if InType not in sizes:
            raise ValueError("unsupported InType: %s" % InType)
        return sizes[InType]

    def compute(regid, InType, use_weights, isa, prefetch):
        """Return the lines that add 8 weighted input columns into ``vop<regid>``.

        ``regid`` is both the accumulator suffix and the column offset into
        the current input row ``ip``. When ``prefetch`` is true, a T0
        prefetch of the same offset in ``ip_next_T0`` is emitted as well;
        otherwise a placeholder comment is emitted.
        """
        code = []
        if InType == "float":
            code.append(
                "vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"
                % (regid, regid, regid))
        elif InType == "float16":
            code.append(
                "vop%d = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128("
                "reinterpret_cast<const __m128i*>(ip + (%d)))), vop%d);"
                % (regid, regid, regid))
        elif InType == "uint8_t":
            # _mm256_cvtepu8_epi32 only consumes the low 8 bytes of its
            # argument, so load exactly 8 bytes. A 16-byte _mm_loadu_si128
            # here would read 8 bytes past the end of the row.
            code.append(
                "vop%d = _mm256_fmadd_ps(vwgt, _mm256_cvtepi32_ps("
                "_mm256_cvtepu8_epi32(_mm_loadl_epi64("
                "reinterpret_cast<const __m128i*>(ip + (%d))))), "
                "_mm256_add_ps(vop%d, vbio));"
                % (regid, regid, regid))
        else:
            raise ValueError("unsupported InType: %s" % InType)
        if prefetch:
            code.append("_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % regid)
        else:
            code.append("// skip unecassery prefetch of (&ip_next_T0[%d])" % regid)
        return code

    code = []
    code.append("// unrolling " + str(uf) + " times")
    code.append(IndexType + " dataInd = 0;")
    code.append("for (" + IndexType
                + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
    code.append(OutType + " *op = &out[rangeIndex * block_size];")
    for i in range(uf):
        j = 8 * i
        code.append("__m256 vop" + str(j) + " = _mm256_setzero_ps();")

    # inner loop over the indices belonging to the current segment
    code.append("for (" + IndexType
                + " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
    code.append("const " + IndexType + " idx = indices[dataInd];")
    if InType == "uint8_t":
        # Quantized rows: scale_bias holds the scale at 2*idx and the bias at
        # 2*idx+1; the optional per-index weight is folded into both.
        code.append(OutType + " wgt = 1.f;")
        code.append(OutType + " bio;")
        code.append("if (weights) {")
        code.append("wgt = weights[dataInd];")
        code.append("}")
        code.append("bio = wgt * scale_bias[2 * indices[dataInd] + 1];")
        code.append("wgt = wgt * scale_bias[2 * indices[dataInd]];")
        code.append("__m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append(OutType + " wgt = 1.f;")
        code.append("if (weights) {")
        code.append("wgt = weights[dataInd];")
        code.append("}")
    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
    code.append("const " + InType + " *ip = &input[idx * block_size];")
    # Index prefdist_T0 positions ahead (clamped to the current position near
    # the end of `indices`); its row is prefetched while this one is summed.
    code.append("const " + IndexType
                + " next_T0 = (dataInd < index_size - prefdist_T0)"
                " ? (dataInd + prefdist_T0) : dataInd;")
    code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "CAFFE_ENFORCE(idx >=0 && idx_pref_T0 >= 0 && idx < data_size && idx_pref_T0 < data_size);")
    code.append("const " + InType
                + " *ip_next_T0 = &input[idx_pref_T0 * block_size];")
    cachelinesize = 64
    for i in range(uf):
        j = 8 * i
        # Only prefetch when the column offset starts a new 64-byte line.
        prefetch = (sizeof(InType) * j) % cachelinesize == 0
        code.extend(compute(j, InType, use_weights, isa, prefetch))
    code.append("}")

    # Write back. With normalize_by_lengths the sums are divided by the
    # segment length. An empty segment has all-zero accumulators and must
    # still be stored; otherwise that output row would be left unwritten.
    code.append("if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
    for i in range(uf):
        j = 8 * i
        code.append("_mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
    code.append("} else {")
    code.append("__m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
    for i in range(uf):
        j = 8 * i
        code.append("_mm256_storeu_ps(&op[" + str(j) + "], _mm256_mul_ps(vop"
                    + str(j) + ", vlen_inv));")
    code.append("}")
    code.append("}")
    return code
def generic(IndexType, InType, OutType, use_weights, isa):
    """Generate C++ lines for an AVX2 embedding-lookup loop over any block_size.

    Unlike ``unroll``, the emitted code does not assume a fixed row width.
    It zeroes the output row ``op`` and then accumulates into it directly,
    8 floats at a time with AVX2/FMA. The remaining ``block_size % 8``
    columns are handled with scalar code. Because the row is zeroed first,
    an empty segment produces zeros. The generated code refers to the
    parameters of the enclosing generated function (``input``, ``indices``,
    ``lengths``, ``weights``, ``scale_bias``, ``out``, ``block_size``,
    ``output_size``, ``index_size``, ``data_size``,
    ``normalize_by_lengths``) and to its local ``prefdist_T0``.

    Args:
        IndexType: C++ type of the ``indices`` array, e.g. ``"int64_t"``.
        InType: C++ element type of ``input``: ``"float"``, ``"float16"``
            or ``"uint8_t"``.
        OutType: C++ type of ``out`` and of the scalar weight.
        use_weights: not read by the generator; kept for interface
            compatibility.
        isa: not read by the generator; kept for interface compatibility.

    Returns:
        list of str: C++ source lines, without trailing newlines.

    Raises:
        ValueError: if ``InType`` is not one of the supported types.
    """

    def compute(InType, use_weights, isa):
        """Return the lines that add 8 weighted columns starting at ``j`` into ``op``.

        Also emits a T0 prefetch of the same offset in ``ip_next_T0``.
        """
        code = []
        if InType == "float":
            code.append(
                "_mm256_storeu_ps(&op[j], _mm256_fmadd_ps(vwgt, "
                "_mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));")
        elif InType == "float16":
            code.append(
                "_mm256_storeu_ps(&op[j], _mm256_fmadd_ps(vwgt, "
                "_mm256_cvtph_ps(_mm_loadu_si128("
                "reinterpret_cast<const __m128i*>(&ip[j]))), "
                "_mm256_loadu_ps(&op[j])));")
        elif InType == "uint8_t":
            # Load exactly the 8 bytes _mm256_cvtepu8_epi32 consumes; a
            # 16-byte _mm_loadu_si128 would read up to 8 bytes past the end
            # of the row.
            code.append(
                "_mm256_storeu_ps(&op[j], _mm256_fmadd_ps(vwgt, "
                "_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64("
                "reinterpret_cast<const __m128i*>(&ip[j])))), "
                "_mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));")
        else:
            raise ValueError("unsupported InType: %s" % InType)
        code.append("_mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);")
        return code

    code = []
    code.append(IndexType + " dataInd = 0;")
    code.append("for (" + IndexType
                + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
    code.append(OutType + " *op = &out[rangeIndex * block_size];")
    # zero the output row: vector part, then scalar tail
    code.append("TIndex j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.append("_mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("}")
    code.append("for(; j < block_size; j++) {")
    code.append("op[j] = 0.0f;")
    code.append("}")

    # inner loop over the indices belonging to the current segment
    code.append("for (" + IndexType
                + " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
    code.append("const " + IndexType + " idx = indices[dataInd];")
    if InType == "uint8_t":
        # Quantized rows: scale_bias holds the scale at 2*idx and the bias at
        # 2*idx+1; the optional per-index weight is folded into both.
        code.append(OutType + " wgt = 1.f;")
        code.append(OutType + " bio;")
        code.append("if (weights) {")
        code.append("wgt = weights[dataInd];")
        code.append("}")
        code.append("assert (scale_bias);")
        code.append("bio = wgt * scale_bias[2 * indices[dataInd] + 1];")
        code.append("wgt = wgt * scale_bias[2 * indices[dataInd]];")
        code.append("__m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append(OutType + " wgt = 1.f;")
        code.append("if (weights) {")
        code.append("wgt = weights[dataInd];")
        code.append("}")
    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
    code.append("const " + InType + " *ip = &input[idx * block_size];")
    # Index prefdist_T0 positions ahead (clamped to the current position near
    # the end of `indices`); its row is prefetched while this one is summed.
    code.append("const " + IndexType
                + " next_T0 = (dataInd < index_size - prefdist_T0)"
                " ? (dataInd + prefdist_T0) : dataInd;")
    code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "CAFFE_ENFORCE(idx >=0 && idx_pref_T0 >= 0 && idx < data_size && idx_pref_T0 < data_size);")
    code.append("const " + InType
                + " *ip_next_T0 = &input[idx_pref_T0 * block_size];")

    # vectorized main loop, 8 columns per step
    code.append("j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.extend(compute(InType, use_weights, isa))
    code.append("}")

    # scalar tail for the remaining block_size % 8 columns
    if InType == "float16":
        # scratch buffer used to convert one float16 at a time via F16C
        code.append("float16 vtmp1[8] CAFFE2_ALIGNED(64);")
    code.append("for(; j < block_size; j++) {")
    if InType == "float":
        code.append("op[j] += wgt * ip[j];")
    elif InType == "float16":
        code.append("vtmp1[0] = ip[j];")
        code.append("__m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));")
        code.append("op[j] += wgt * ((float*)(&vtmp2))[0];")
    elif InType == "uint8_t":
        code.append("op[j] += wgt * ((float)ip[j]) + bio;")
    else:
        raise ValueError("unsupported InType: %s" % InType)
    code.append("}")
    code.append("}")

    # optional division by the segment length (skipped for empty segments,
    # whose row is already zero)
    code.append("if (normalize_by_lengths && lengths[rangeIndex]) {")
    code.append("float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("__m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.append(
        "_mm256_storeu_ps(&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));")
    code.append("}")
    code.append("for(; j < block_size; j++) {")
    code.append("op[j] = len_inv * op[j];")
    code.append("}")
    code.append("}")
    code.append("}")
    return code
# start main code
# Script entry point: parse the output file name, generate one
# EmbeddingLookup_<IndexType>_<InType>_<OutType>__avx2_fma function per entry
# in `options`, and write them all to a single C++ source file.
parser = argparse.ArgumentParser()
parser.add_argument('-f', nargs=1, help="file name")
opts = parser.parse_args()
filename = opts.f[0] if opts.f else "embedding_lookup_avx2.cc"

# (IndexType, InType, OutType) combinations to emit.
options = [["int32_t", "float", "float"],
           ["int64_t", "float", "float"],
           ["int32_t", "float16", "float"],
           ["int64_t", "float16", "float"],
           ["int32_t", "uint8_t", "float"],
           ["int64_t", "uint8_t", "float"],
           ]

code = []
# banner and includes
code.append("//// --------------------------")
code.append("//// ATTENTION: ")
code.append("//// THIS CODE IS AUTOGENERATED")
code.append("//// BY %s " % (sys.argv[0]))
code.append("//// DO NOT MODIFY!!! ")
code.append("//// --------------------------\n\n")
code.append("#include \"caffe2/core/types.h\"")
code.append("#include \"caffe2/core/common.h\"")
code.append("#include <immintrin.h>")
code.append("\n")
code.append("namespace caffe2 {\n")
for IndexType, InType, OutType in options:
    fn = ("void EmbeddingLookup_" + IndexType + "_" + InType + "_" + OutType
          + "__avx2_fma")
    code.append(fn + "(")
    code.append("const TIndex block_size,")
    code.append("const TIndex output_size,")
    code.append("const TIndex index_size,")
    code.append("const TIndex data_size,")
    code.append("const " + InType + "* input,")
    code.append("const " + IndexType + "* indices,")
    code.append("const int* lengths,")
    code.append("const float* weights,")
    code.append("const float* scale_bias,")
    code.append("bool normalize_by_lengths,")
    code.append(OutType + "* out)")
    code.append("{")
    code.append("const " + IndexType + " prefdist_T0 = 16;")
    # Only the uint8_t kernels read scale_bias; enforce its presence/absence.
    if InType != "uint8_t":
        code.append('CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");')
    else:
        code.append('CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");')
    # Fully unrolled kernels (one 8-float accumulator per unroll step) for the
    # listed block sizes; every other block_size uses the generic loop.
    for k, block_size in enumerate((128, 64, 32, 16)):
        branch = "if" if k == 0 else "} else if"
        code.append(branch + " (block_size == " + str(block_size) + ") {")
        code.extend(unroll(block_size // 8, IndexType, InType, OutType, True, "AVX2"))
    code.append("} else {")
    code.append("// generic code")
    code.extend(generic(IndexType, InType, OutType, True, "AVX2"))
    code.append("}")
    code.append("}")
    code.append("\n")
code.append("} // namespace caffe2")

# Open the output only after generation has succeeded, so a generator error
# does not leave a truncated file behind; `with` guarantees it is closed.
with open(filename, 'w') as fout:
    fout.writelines(c + "\n" for c in code)
print("Created " + filename)