from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import sys

sizeof = {'float': 4, 'float16': 2, 'uint8_t': 1}
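# sizeof maps each input type to its size in bytes; unroll() uses it to issue
# one software prefetch per 64-byte cache line (16 floats, 32 float16 values,
# or 64 uint8_t values).

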
def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused):
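    """Emit C++ for a fully unrolled kernel: uf accumulator registers of 8
    floats each, i.e. a block_size of exactly 8 * uf kept in registers."""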
    def compute(regid, InType, use_weights, isa, prefetch):
        code = []
        if InType == "float":
            # accumulate: vop += wgt * input
            code.append(
                "vop%d = _mm256_fmadd_ps(vwgt, \
_mm256_loadu_ps(ip + (%d)), vop%d);"
                % (regid, regid, regid)
            )
        elif InType == "float16":
            # widen 8 half-precision values to float, then accumulate
            code.append(
                "vop%d = _mm256_fmadd_ps(vwgt, \
_mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))), \
vop%d);"
                % (regid, regid, regid)
            )
        elif InType == "uint8_t":
            # widen 8 bytes to int32, convert to float, and fold in the bias
            code.append(
                "vop%d = _mm256_fmadd_ps(vwgt, \
_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))), \
_mm256_add_ps(vop%d, vbio));"
                % (regid, regid, regid)
            )
        else:
            assert False, "unknown InType: " + InType

        if prefetch:
            code.append(
                "_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid))
        else:
            code.append(
                "// skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid))
        return code

    code = []
    code.append("// unrolling " + str(uf) + " times")
    code.append(IndexType + " dataInd = 0;")
    code.append("for (" + IndexType +
                " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
    code.append(OutType + " *op = &out[rangeIndex * block_size];")

    # one 8-float accumulator register per group of 8 output elements
    for i in range(0, uf):
        j = 8 * i
        code.append("__m256 vop" + str(j) + " = _mm256_setzero_ps();")

    # inner loop over the indices belonging to this output range
    code.append("for (" + IndexType +
                " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
    code.append("const " + IndexType + " idx = indices[dataInd];")
    code.append(
        'CAFFE_ENFORCE(idx >= 0 && idx < data_size, "Index ", dataInd, "'
        ' is out of bounds: ", idx, ", range 0 to ", data_size);')
    if InType == "uint8_t":
        code.append(OutType + " wgt = 1.f;")
        code.append(OutType + " bio;")
        code.append("if (weights) {")
        code.append(
            "wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];")
        code.append("}")
        if fused:
            # fused rows append scale and bias (two floats) after the data
            code.append(
                'const float* scale_bias = reinterpret_cast<'
                'const float*>(&input[idx * fused_block_size + block_size]);'
            )
            code.append("bio = wgt * scale_bias[1];")
            code.append("wgt = wgt * scale_bias[0];")
        else:
            code.append("bio = wgt * scale_bias[2 * idx + 1];")
            code.append("wgt = wgt * scale_bias[2 * idx];")
        code.append("__m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append(OutType + " wgt = 1.f;")
        code.append("if (weights) {")
        code.append(
            "wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];")
        code.append("}")
    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
    code.append("const {} *ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        'const {} next_T0 = (dataInd < index_size - prefdist_T0)'
        ' ? (dataInd + prefdist_T0) : dataInd;'.format(IndexType)
    )
    code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
    code.append(
        'const {} *ip_next_T0 = &input[idx_pref_T0'
        ' * fused_block_size];'.format(InType)
    )

    # issue at most one software prefetch per 64-byte cache line instead of
    # one per register
    for i in range(0, uf):
        j = 8 * i
        cachelinesize = 64
        byteoffset = sizeof[InType] * j
        prefetch = (byteoffset % cachelinesize) == 0
        code.extend(compute(j, InType, use_weights, isa, prefetch))
    code.append("}")

    code.append("if (normalize_by_lengths == false) {")
    for i in range(0, uf):
        j = 8 * i
        code.append(
            "_mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
    code.append("} else if (lengths[rangeIndex]) {")
    # normalize by the reciprocal of the segment length
    code.append(
        "__m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
    for i in range(0, uf):
        j = 8 * i
        code.append(
            "_mm256_storeu_ps(&op[" + str(j) + "], _mm256_mul_ps(" +
            "vop" + str(j) + ", vlen_inv));")
    code.append("}")
    code.append("}")
    return code
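
# For a sense of the generated code: unroll(2, "int32_t", "float", ...) emits
# an inner loop that accumulates
#   vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
#   vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8;
# i.e. one fused multiply-add per 8-float accumulator register.
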
def generic(IndexType, InType, OutType, use_weights, isa, fused):
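    """Emit C++ for the fallback kernel: loops over an arbitrary block_size,
    8 floats at a time, with a scalar tail for the remainder."""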
    def compute(InType, use_weights, isa):
        code = []
        if InType == "float":
            code.append(
                "_mm256_storeu_ps(&op[j], \
_mm256_fmadd_ps(vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])) \
);"
            )
        elif InType == "float16":
            code.append(
                "_mm256_storeu_ps(&op[j], \
_mm256_fmadd_ps(vwgt, \
_mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(&ip[j]))), _mm256_loadu_ps(&op[j])) \
);"
            )
        elif InType == "uint8_t":
            code.append(
                "_mm256_storeu_ps(&op[j], \
_mm256_fmadd_ps(vwgt, \
_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(&ip[j])))), \
_mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)) \
);"
            )
        else:
            assert False, "unknown InType: " + InType

        # block_size is unknown at generation time, so prefetch unconditionally
        code.append("_mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);")
        return code

    code = []
    code.append(IndexType + " dataInd = 0;")
    code.append("for (" + IndexType +
                " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
    code.append(OutType + " *op = &out[rangeIndex * block_size];")

    # zero the output row, 8 floats at a time, then a scalar remainder
    code.append("TIndex j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.append("_mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("}")
    code.append("for(; j < block_size; j++) {")
    code.append("op[j] = 0.0f;")
    code.append("}")

    # inner loop over the indices belonging to this output range
    code.append("for (" + IndexType +
                " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
    code.append("const " + IndexType + " idx = indices[dataInd];")
    code.append(
        'CAFFE_ENFORCE(idx >= 0 && idx < data_size, "Index ", dataInd, "'
        ' is out of bounds: ", idx, ", range 0 to ", data_size);')
    if InType == "uint8_t":
        code.append(OutType + " wgt = 1.f;")
        code.append(OutType + " bio;")
        code.append("if (weights) {")
        code.append(
            "wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];")
        code.append("}")
        if fused:
            # fused rows append scale and bias (two floats) after the data
            code.append(
                'const float* scale_bias = reinterpret_cast<'
                'const float*>(&input[idx * fused_block_size + block_size]);'
            )
            code.append("bio = wgt * scale_bias[1];")
            code.append("wgt = wgt * scale_bias[0];")
        else:
            code.append("assert(scale_bias);")
            code.append("bio = wgt * scale_bias[2 * idx + 1];")
            code.append("wgt = wgt * scale_bias[2 * idx];")
        code.append("__m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append(OutType + " wgt = 1.f;")
        code.append("if (weights) {")
        code.append(
            "wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];")
        code.append("}")
    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
    code.append("const {} *ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        'const {} next_T0 = (dataInd < index_size - prefdist_T0)'
        ' ? (dataInd + prefdist_T0) : dataInd;'.format(IndexType)
    )
    code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
    code.append(
        "const {} *ip_next_T0 = &input[idx_pref_T0 * fused_block_size];".
        format(InType)
    )

    # vectorized main loop over the row
    code.append("j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.extend(compute(InType, use_weights, isa))
    code.append("}")

    # scalar tail for the remaining (block_size % 8) elements
    if InType == "float16":
        # half->float conversion goes through an aligned staging buffer
        code.append("float16 vtmp1[8] CAFFE2_ALIGNED(64);")
    code.append("for(; j < block_size; j++) {")
    if InType == "float":
        code.append("op[j] += wgt * ip[j];")
    elif InType == "float16":
        code.append("vtmp1[0] = ip[j];")
        code.append("__m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));")
        code.append("op[j] += wgt * ((float*)(&vtmp2))[0];")
    elif InType == "uint8_t":
        code.append("op[j] += wgt * ((float)ip[j]) + bio;")
    else:
        assert False, "unknown InType: " + InType
    code.append("}")
    code.append("}")

    code.append("if (normalize_by_lengths && lengths[rangeIndex]) {")
    code.append("float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("__m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.append(
        "_mm256_storeu_ps(&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));")
    code.append("}")
    code.append("for(; j < block_size; j++) {")
    code.append("op[j] = len_inv * op[j];")
    code.append("}")
    code.append("}")
    code.append("}")
    return code

# start main code
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--filename', help="file name")
parser.add_argument('--fused', action='store_true')
opts = parser.parse_args()
if opts.filename:
    filename = opts.filename
elif opts.fused:
    filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
else:
    filename = "embedding_lookup_avx2.cc"
fout = open(filename, 'w')
# (IndexType, InType, OutType) triples to generate specializations for
options = [["int32_t", "float", "float"],
           ["int64_t", "float", "float"],
           ["int32_t", "float16", "float"],
           ["int64_t", "float16", "float"],
           ["int32_t", "uint8_t", "float"],
           ["int64_t", "uint8_t", "float"]]
code = []
# includes
code.append("//// --------------------------")
code.append("//// ATTENTION:")
code.append("//// THIS CODE IS AUTOGENERATED")
code.append("//// BY {}".format(sys.argv[0]))
code.append("//// DO NOT MODIFY!!!")
code.append("//// --------------------------\n\n")
code.append("#include <caffe2/core/types.h>")
code.append("#include <caffe2/core/common.h>")
code.append("#include <immintrin.h>")
code.append("\n")
code.append("namespace caffe2 {\n")
for o in options:
    [IndexType, InType, OutType] = o
    prefix = 'Fused8BitRowwise' if opts.fused else ''
    code.append('template <bool IS_WEIGHT_POSITIONAL>')
    fn_base = '{}EmbeddingLookup_{}_{}_{}'.format(
        prefix, IndexType, InType, OutType
    )
    suffix = '__avx2_fma'
    fn = "static void " + fn_base + suffix
    code.append(fn + "(")
    args = []
    args.append("const TIndex block_size,")
    args.append("const TIndex output_size,")
    args.append("const TIndex index_size,")
    args.append("const TIndex data_size,")
    args.append("const " + InType + "* input,")
    args.append("const " + IndexType + "* indices,")
    args.append("const int* lengths,")
    args.append("const float* weights,")
    if not opts.fused:
        args.append("const float* scale_bias,")
    args.append("bool normalize_by_lengths,")
    args.append(OutType + "* out)")
    code += args
    code.append("{")
    code.append("const " + IndexType + " prefdist_T0 = 16;")
    # block_size is the number of elements and fused_block_size is the size
    # of an entire row, including scale and bias.
    offset = (8 // sizeof[InType]) if opts.fused else 0
    code.append(
        "const {} fused_block_size = block_size + {};".
        format(IndexType, offset)
    )
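    # e.g. fused uint8_t rows carry scale and bias as two trailing floats
    # (8 bytes = 8 // 1 = 8 extra uint8_t-sized elements per row); fused
    # float rows carry 8 // 4 = 2 extra elements.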
    # code.append("printf(\"calling " + fn + "\\n\");")
    if not opts.fused:
        if InType != "uint8_t":
            code.append(
                'CAFFE_ENFORCE(scale_bias == nullptr,'
                ' "scale_bias must be nullptr");'
            )
        else:
            code.append(
                'CAFFE_ENFORCE(scale_bias != nullptr,'
                ' "scale_bias must not be nullptr");'
            )

    # dispatch: fully unrolled kernels for common block sizes (16 registers
    # of 8 floats cover 128 elements), generic loop for everything else
    code.append("if (block_size == 128) {")
    code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused)
    code.append("} else if (block_size == 64) {")
    code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused)
    code.append("} else if (block_size == 32) {")
    code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused)
    code.append("} else if (block_size == 16) {")
    code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused)
    code.append("} else {")
    code.append("// generic code")
    code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused)
    code.append("}")
    code.append("}")
    for is_weight_positional in ['false', 'true']:
        code.append(
            "void " + fn_base + "_" + is_weight_positional + suffix + "(")
        code += args
        code.append("{")
        code.append(fn_base + suffix + "<" + is_weight_positional + ">(")
        code.append("block_size,")
        code.append("output_size,")
        code.append("index_size,")
        code.append("data_size,")
        code.append("input,")
        code.append("indices,")
        code.append("lengths,")
        code.append("weights,")
        if not opts.fused:
            code.append("scale_bias,")
        code.append("normalize_by_lengths,")
        code.append("out);")
        code.append("}")
    code.append("\n")
code.append("} // namespace caffe2")

for c in code:
    # print(c, file=fout)
    fout.write(c + "\n")
fout.close()
print("Created " + filename)