blob: 5a6643c2aa67aee4c5a9742951c02c805ca0e411 [file] [log] [blame]
#include "caffe2/opt/optimize_ideep.h"
#include "caffe2/opt/converter.h"
#include "caffe2/opt/fusion.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
namespace opt {

using namespace nom;

/// Rewrites @p nn in place for the IDEEP backend.
///
/// The only rewrite performed here is Conv+Relu fusion, delegated to
/// fuseActivation<repr::Conv, repr::Relu>.
/// - A conv is eligible when it carries a Caffe2 OperatorDef, that op is
///   placed on the IDEEP device, and it is not a grouped convolution.
/// - After fusion, the conv's OperatorDef is retyped to "ConvFusion" and
///   receives a "fusion_type" argument.
///
/// @param nn  Module to optimize; mutated in place.
void OptimizeForIdeep(repr::NNModule* nn) {
  // Eligibility predicate: may this conv absorb the Relu that follows it?
  auto should_fuse = [](const repr::Conv& conv) -> bool {
    const auto annotation = conv.getAnnotation();
    // Without a Caffe2 annotation there is no OperatorDef to inspect.
    const bool has_caffe2_op = annotation && isa<Caffe2Annotation>(annotation);
    if (!has_caffe2_op) {
      return false;
    }
    const auto& op_def =
        dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();
    // Only fuse convs that are placed on the IDEEP device. This check runs
    // before the "group" argument is read, so non-IDEEP ops are never parsed.
    if (op_def.device_option().device_type() != DeviceType::IDEEP) {
      return false;
    }
    // IDEEP doesn't support fusion for grouped conv. An absent "group"
    // argument is treated as 1.
    return ArgumentHelper::GetSingleArgument<OperatorDef, int>(
               op_def, "group", 1) == 1;
  };

  // Runs on each conv node that absorbed a Relu. It marks the underlying
  // Caffe2 op as a fused conv.
  auto postprocess = [](repr::NNGraph::NodeRef conv_node) {
    // Value stored in "fusion_type"; 1 means FUSION_CONV_RELU.
    constexpr int kFusionConvRelu = 1;

    auto conv_op = repr::nn::get<repr::Conv>(conv_node);
    auto annotation = conv_op->getMutableAnnotation();
    if (annotation && isa<Caffe2Annotation>(annotation)) {
      auto* op_def =
          dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
      op_def->set_type("ConvFusion");
      auto* fusion_arg = op_def->add_arg();
      fusion_arg->set_name("fusion_type");
      fusion_arg->set_i(kFusionConvRelu);
    }
  };

  fuseActivation<repr::Conv, repr::Relu>(nn, should_fuse, postprocess);
}

} // namespace opt
} // namespace caffe2