#include "caffe2/opt/fakefp16_transform.h"
#include "caffe2/opt/glow_net_transform.h"
#include "caffe2/utils/proto_utils.h"
C10_DEFINE_bool(
fake_fp16_conversion_use_fp16_acc,
false,
"Whether to enable fp16 accumulation for FC / BatchMatMul for fakefp16 "
"operators.");
C10_DEFINE_bool(
fake_fp16_conversion_use_nnpi,
false,
"Whether to simulate NNPI behavior for fakefp16 operators.");
namespace caffe2 {
namespace opt {
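// Builds the mapping from regular operator types to their fakefp16
// counterparts. With use_fp16_acc, the FC and (Batch)MatMul variants
// accumulate in fp16 instead of fp32; with use_nnpi, Sigmoid and Tanh map to
// their NNPI-emulating variants.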
std::unordered_map<std::string, std::string> getFakeFp16OpMapping(
bool use_fp16_acc,
bool use_nnpi) {
std::unordered_map<std::string, std::string> fake_fp16_op_conversion_map = {
{"FC", "Fp16FCAcc32NNPI"},
{"Int8FC", "Int8FCFakeAcc32NNPI"},
{"Int8Quantize", "Int8QuantizeNNPI"},
{"Int8Dequantize", "Int8DequantizeNNPI"},
{"LayerNorm", "LayerNormFakeFP16NNPI"},
{"FbFCPacked", "Fp16FCAcc32NNPI"},
{"Logit", "LogitFakeFp16NNPI"},
{"SparseLengthsSum", "SparseLengthsSumFakeFP16AccFP16"},
{"SparseLengthsWeightedSum", "SparseLengthsWeightedSumFakeFP16AccFP16"},
{"SparseLengthsMean", "SparseLengthsMeanFakeFP16AccFP16"},
{"SparseLengthsSumFused4BitRowwise",
"SparseLengthsSumFused4BitRowwiseFakeFP16NNPI"},
{"SparseLengthsWeightedSumFused4BitRowwise",
"SparseLengthsWeightedSumFused4BitRowwiseFakeFP16NNPI"},
{"SparseLengthsSumFused8BitRowwise",
"SparseLengthsSumFused8BitRowwiseFakeFP16NNPI"},
{"SparseLengthsWeightedSumFused8BitRowwise",
"SparseLengthsWeightedSumFused8BitRowwiseFakeFP16NNPI"},
{"SparseLengthsMeanFused8BitRowwise",
"SparseLengthsMeanFused8BitRowwiseFakeFP16AccFP16"},
{"MatMul", "BatchMatMulFP16Acc32Fake"},
{"BatchMatMul", "BatchMatMulFP16Acc32Fake"},
{"Sigmoid", "SigmoidFakeFp16"},
{"SpatialBN", "SpatialBNFakeFp16NNPI"},
{"Swish", "SwishFakeFp16NNPI"},
{"Tanh", "TanhFakeFp16"},
{"Relu", "ReluFakeFp16"},
{"Add", "AddFakeFp16"},
{"Sub", "SubFakeFp16"},
{"Mul", "MulFakeFp16"},
{"Div", "DivFakeFp16"},
{"Sum", "SumFakeFp16"},
{"Sqr", "SqrFakeFp16"},
{"LengthsSum", "LengthsSumFakeFp16"}};
if (use_fp16_acc) {
fake_fp16_op_conversion_map["FC"] = "Fp16FCAcc16NNPI";
fake_fp16_op_conversion_map["FbFCPacked"] = "Fp16FCAcc16NNPI";
fake_fp16_op_conversion_map["BatchMatMul"] = "BatchMatMulFP16Acc16Fake";
fake_fp16_op_conversion_map["MatMul"] = "BatchMatMulFP16Acc16Fake";
}
if (use_nnpi) {
fake_fp16_op_conversion_map["Sigmoid"] = "SigmoidFakeFp16NNPI";
fake_fp16_op_conversion_map["Tanh"] = "TanhFakeFp16NNPI";
}
return fake_fp16_op_conversion_map;
}
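// Returns mutable pointers to every operator in `net` that consumes `input`.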
std::vector<OperatorDef*> findMutableOperatorByInput(
NetDef* net,
const std::string& input) {
std::vector<OperatorDef*> ops;
for (auto& op : *net->mutable_op()) {
for (const auto& i : op.input()) {
if (input == i) {
ops.push_back(&op);
}
}
}
return ops;
}
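// Fuses a LayerNormFakeFP16NNPI -> MulFakeFp16 -> AddFakeFp16 chain (layer
// norm followed by an elementwise scale and bias) into a single
// LayerNormFakeFP16NNPI op.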
void fakeFp16FoldLayerNorm(NetDef* net) {
for (auto& op : *net->mutable_op()) {
if (op.type() == "LayerNormFakeFP16NNPI") {
LOG(INFO) << "Attemping to fuse LayerNormFakeFP16NNPI at "
<< ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "net_pos", -1);
if (op.input().size() != 1) {
LOG(INFO) << "input isn't 1, skipping";
continue;
}
const std::string& ln_output = op.output(0);
auto next_ops = findMutableOperatorByInput(net, ln_output);
if (next_ops.size() != 1 || next_ops[0]->type() != "MulFakeFp16") {
LOG(INFO) << "next op isn't MulFakeFp16, skipping";
continue;
}
auto* mul_op = next_ops[0];
auto next_next_ops = findMutableOperatorByInput(net, mul_op->output(0));
if (next_next_ops.size() != 1 ||
next_next_ops[0]->type() != "AddFakeFp16") {
LOG(INFO) << "next op isn't AddFakeFp16, skipping";
continue;
}
auto* add_op = next_next_ops[0];
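      // Absorb the Mul's scale and the Add's bias as extra LayerNorm inputs,
      // take over the Add's output, and mark the Mul and Add for removal.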
*(op.mutable_input()->Add()) = mul_op->input(1);
*(op.mutable_input()->Add()) = add_op->input(1);
*op.mutable_output(0) = add_op->output(0);
mul_op->set_type("delete_me_optimized_away");
add_op->set_type("delete_me_optimized_away");
LOG(INFO) << "Fused LayerNormFakeFP16NNPI";
}
}
}
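// Fuses a LayerNormFakeFP16NNPI -> Int8QuantizeNNPI chain into a single
// LayerNormInt8QuantizeFakeNNPI op.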
void fakeFp16FoldLayerNormQuant(NetDef* net) {
for (auto& op : *net->mutable_op()) {
if (op.type() == "LayerNormFakeFP16NNPI") {
auto layernormNetPos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "net_pos", -1);
LOG(INFO) << "Attemping to fuse LayerNormFakeFP16NNPI w Quant at "
<< layernormNetPos;
if (op.input().size() != 1) {
LOG(INFO) << "input isn't 1, is " << op.input().size() << " skipping";
continue;
}
const std::string& ln_output = op.output(0);
auto next_ops = findMutableOperatorByInput(net, ln_output);
if (next_ops.size() != 1 || next_ops[0]->type() != "Int8QuantizeNNPI") {
LOG(INFO) << "next op isn't Int8QuantizeNNPI, skipping";
continue;
}
auto* quantOp = next_ops[0];
if (quantOp->output().size() != 1) {
LOG(INFO) << "more than one output for quant, skipping";
continue;
}
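      // Fold the quantization into the LayerNorm: take over the quant op's
      // output and its Y_scale / Y_zero_point arguments, then mark the quant
      // op for removal.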
op.set_type("LayerNormInt8QuantizeFakeNNPI");
*op.mutable_output(0) = quantOp->output(0);
op.add_arg()->CopyFrom(MakeArgument("Y_scale",
ArgumentHelper::GetSingleArgument<OperatorDef, float>(*quantOp, "Y_scale", -1)));
op.add_arg()->CopyFrom(MakeArgument("Y_zero_point",
ArgumentHelper::GetSingleArgument<OperatorDef, int>(*quantOp, "Y_zero_point", -1)));
auto quantNetPos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
*quantOp, "net_pos", -1);
quantOp->set_type("delete_me_optimized_away");
LOG(INFO) << "Fused LayerNormFakeFP16NNPI w Quant at " << layernormNetPos << " " << quantNetPos;
}
}
}
void fakeFp16FoldSwish(NetDef* net) {
// find a sequence deq->swish->quant and replace it
for (auto& op : *net->mutable_op()) {
if (op.type() == "Int8DequantizeNNPI") {
auto deq_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "net_pos", -1);
LOG(INFO) << "Attempting swish fusion at " << deq_net_pos;
if (op.output().size() != 1) {
LOG(INFO) << "more than one output deq, skipping";
continue;
}
const std::string& deqOutput = op.output(0);
auto next_ops = findMutableOperatorByInput(net, deqOutput);
if (next_ops.size() != 1 || next_ops[0]->type() != "SwishFakeFp16NNPI") {
LOG(INFO) << "skipping, next op is " << next_ops[0]->type();
continue;
}
auto* swishOp = next_ops[0];
if (swishOp->output().size() != 1) {
LOG(INFO) << "more than one output for swish, skipping";
continue;
}
auto next_next_ops = findMutableOperatorByInput(net, swishOp->output(0));
      if (next_next_ops.size() != 1 ||
          next_next_ops[0]->type() != "Int8QuantizeNNPI") {
        LOG(INFO) << "skipping, next op isn't quant, is "
                  << (next_next_ops.empty() ? "missing" : next_next_ops[0]->type());
continue;
}
auto* quantOp = next_next_ops[0];
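      // Rewrite the dequantize op into the fused int8 swish: take over the
      // quant op's output and quantization arguments, then mark the original
      // swish and quant ops for removal.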
op.set_type("SwishFakeInt8NNPI");
*op.mutable_output(0) = quantOp->output(0);
op.add_arg()->CopyFrom(MakeArgument("Y_scale",
ArgumentHelper::GetSingleArgument<OperatorDef, float>(*quantOp, "Y_scale", -1)));
op.add_arg()->CopyFrom(MakeArgument("Y_zero_point",
ArgumentHelper::GetSingleArgument<OperatorDef, int>(*quantOp, "Y_zero_point", -1)));
auto swish_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
*swishOp, "net_pos", -1);
auto quant_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
*quantOp, "net_pos", -1);
swishOp->set_type("delete_me_optimized_away");
quantOp->set_type("delete_me_optimized_away");
LOG(INFO) << "Fusing swish at " << deq_net_pos << ", " << swish_net_pos << ", " << quant_net_pos;
}
}
}
void fakeFp16FoldTanhQuant(NetDef* net) {
  // find a sequence tanh->quant and replace it
for (auto& op : *net->mutable_op()) {
if (op.type() == "TanhFakeFp16NNPI") {
auto tanh_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "net_pos", -1);
LOG(INFO) << "Attempting tanh fusion at " << tanh_net_pos;
if (op.output().size() != 1) {
LOG(INFO) << "more than one output for tanh, skipping";
continue;
}
const std::string& tanhOutput = op.output(0);
auto next_ops = findMutableOperatorByInput(net, tanhOutput);
if (next_ops.size() != 1 || next_ops[0]->type() != "Int8QuantizeNNPI") {
LOG(INFO) << "skipping, next op is " << next_ops[0]->type();
continue;
}
auto* quantOp = next_ops[0];
if (quantOp->output().size() != 1) {
LOG(INFO) << "more than one output for quant, skipping";
continue;
}
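      // Fold the quantization into the tanh: take over the quant op's output
      // and quantization arguments, then mark the quant op for removal.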
op.set_type("TanhQuantFakeFp16NNPI");
*op.mutable_output(0) = quantOp->output(0);
op.add_arg()->CopyFrom(MakeArgument("Y_scale",
ArgumentHelper::GetSingleArgument<OperatorDef, float>(*quantOp, "Y_scale", -1)));
op.add_arg()->CopyFrom(MakeArgument("Y_zero_point",
ArgumentHelper::GetSingleArgument<OperatorDef, int>(*quantOp, "Y_zero_point", -1)));
auto quant_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
*quantOp, "net_pos", -1);
quantOp->set_type("delete_me_optimized_away");
LOG(INFO) << "Fusing tanh and quant at " << tanh_net_pos << ", " << quant_net_pos;
}
}
}
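// Applies all fakefp16 fusion passes to `net`.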
void fakeFp16FuseOps(NetDef* net) {
LOG(INFO) << "Running Fp16 Fusion";
// We should fuse the groups of bigger operators first
fakeFp16FoldLayerNorm(net);
fakeFp16FoldSwish(net);
fakeFp16FoldTanhQuant(net);
fakeFp16FoldLayerNormQuant(net);
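  // Erase the ops that the fusion passes marked as optimized away.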
auto iter = net->mutable_op()->begin();
while (iter != net->mutable_op()->end()) {
if (iter->type() == "delete_me_optimized_away") {
iter = net->mutable_op()->erase(iter);
} else {
++iter;
}
}
}
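// Rewrites the part of `net` that will be lowered to ONNXIFI to use fakefp16
// operators, honoring the onnxifi blocklists, and then runs the fusion passes.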
void fakeFp16Transform(NetDef* net) {
static const std::unordered_map<std::string, std::string>
kFakeFp16OpConversionMap = getFakeFp16OpMapping(
FLAGS_fake_fp16_conversion_use_fp16_acc,
FLAGS_fake_fp16_conversion_use_nnpi);
auto blocklist_pos = glow::ParseNetPositionList(FLAGS_onnxifi_blacklist);
auto blocklist_type = glow::ParseBlockListOps(FLAGS_onnxifi_blacklist_ops);
  // A hack to only apply the fakefp16 transformation to operators that will
  // be lowered to ONNXIFI.
  // TODO(yingz): Use more deterministic logic to figure out which operators
  // can be lowered to ONNXIFI instead.
int last_clip_idx = -1;
for (int i = 0; i < net->op().size(); ++i) {
const auto& op = net->op(i);
if (op.type() == "Clip") {
last_clip_idx = i;
}
}
for (int i = 0; i < net->op().size(); ++i) {
if (i <= last_clip_idx) {
continue;
}
auto* op = net->mutable_op(i);
auto net_pos =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(*op, "net_pos", -1);
if (blocklist_pos.count(net_pos) || blocklist_type.count(op->type())) {
continue;
}
auto it = kFakeFp16OpConversionMap.find(op->type());
if (it != kFakeFp16OpConversionMap.end()) {
op->set_type(it->second);
}
}
fakeFp16FuseOps(net);
}
} // namespace opt
} // namespace caffe2