// blob: 71297a44d4ff094e9c2e6c3ebf1c349f60fe2206 [file] [log] [blame]
#include "caffe2/opt/custom/fakefp16_transform.h"
#include "caffe2/opt/custom/glow_net_transform.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
namespace opt {
// Rewrites the type of eligible operators in `net` to their fake-fp16
// counterparts, in place.
//
// An operator is rewritten only when all of the following hold:
//   * it comes after the last "Clip" operator in the net (if there is no
//     "Clip" operator, every operator is considered);
//   * its "net_pos" argument (-1 when absent) is not in the position list
//     parsed from FLAGS_onnxifi_blacklist;
//   * its type is not in the op list parsed from FLAGS_onnxifi_blacklist_ops;
//   * its type is a key of the conversion table below.
// All other operators are left untouched.
void fakeFp16Transform(NetDef* net) {
  // Original operator type -> fake-fp16 replacement type.
  static const std::unordered_map<std::string, std::string>
      kFakeFp16OpConversionMap = {
          {"FC", "Fp16FCAcc16NNPI"},
          {"SparseLengthsSum", "SparseLengthsSumFakeFP16AccFP16"},
          {"SparseLengthsWeightedSum",
           "SparseLengthsWeightedSumFakeFP16AccFP16"},
          {"SparseLengthsMean", "SparseLengthsMeanFakeFP16AccFP16"},
          {"SparseLengthsSumFused8BitRowwise",
           "SparseLengthsSumFused8BitRowwiseFakeFP16AccFP16"},
          {"SparseLengthsWeightedSumFused8BitRowwise",
           "SparseLengthsWeightedSumFused8BitRowwiseFakeFP16AccFP16"},
          {"SparseLengthsMeanFused8BitRowwise",
           "SparseLengthsMeanFused8BitRowwiseFakeFP16AccFP16"},
          {"BatchMatMul", "BatchMatMulFP16Acc16Fake"},
          {"Sigmoid", "SigmoidFakeFp16"},
          {"Tanh", "TanhFakeFp16"},
          {"Relu", "ReluFakeFp16"},
          {"Add", "AddFakeFp16"},
          {"Sub", "SubFakeFp16"},
          {"Mul", "MulFakeFp16"},
          {"Div", "DivFakeFp16"},
          {"Sum", "SumFakeFp16"},
          {"Sqr", "SqrFakeFp16"},
          {"LengthsSum", "LengthsSumFakeFp16"}};

  const auto skipped_positions =
      glow::ParseNetPositionList(FLAGS_onnxifi_blacklist);
  const auto skipped_types =
      glow::ParseBlackListOps(FLAGS_onnxifi_blacklist_ops);

  // A hack to only do fakefp16 transformation for operators which will be
  // lowered to ONNXIFI: everything up to and including the last "Clip"
  // operator is skipped.
  // TODO(yingz): Use more deterministic logics to figure out operators which
  // can be lowered to ONNXIFI instead.
  //
  // Scan backwards so the first "Clip" encountered is the last one in the net;
  // the rewrite then starts at the operator right after it.
  int first_candidate = 0;
  for (int idx = net->op().size() - 1; idx >= 0; --idx) {
    if (net->op(idx).type() == "Clip") {
      first_candidate = idx + 1;
      break;
    }
  }

  for (int idx = first_candidate; idx < net->op().size(); ++idx) {
    auto* current = net->mutable_op(idx);

    // Operators without a "net_pos" argument report -1 here.
    const auto position = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
        *current, "net_pos", -1);
    const bool is_blacklisted = skipped_positions.count(position) ||
        skipped_types.count(current->type());
    if (is_blacklisted) {
      continue;
    }

    const auto replacement = kFakeFp16OpConversionMap.find(current->type());
    if (replacement == kFakeFp16OpConversionMap.end()) {
      // No fake-fp16 counterpart for this operator type; keep it as is.
      continue;
    }
    current->set_type(replacement->second);
  }
}
} // namespace opt
} // namespace caffe2