// blob: 5e16a884011fa40389fc931321afbefcade437d3 [file] [log] [blame]
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/opt/backend_cutting.h"
#include "caffe2/utils/string_utils.h"
#include <gtest/gtest.h>
namespace {
using caffe2::StartsWith;
// Appends one "MyConv" operator to `net`. The op reads blobs "N<tick>",
// "W<tick>" and "b<tick>" (in that order) and writes blob "N<tick + 1>".
void AddConv(caffe2::NetDef* net, int tick) {
  const auto in_suffix = c10::to_string(tick);
  const auto out_suffix = c10::to_string(tick + 1);

  auto* conv = net->add_op();
  conv->set_type("MyConv");
  conv->add_input("N" + in_suffix);
  conv->add_input("W" + in_suffix);
  conv->add_input("b" + in_suffix);
  conv->add_output("N" + out_suffix);
}
// Predicate handed to OptimizeForBackend by the tests in this file: an
// operator counts as supported when its type begins with "MyConv", "MyRelu"
// or "Concat".
bool Supports(const caffe2::OperatorDef& op) {
  const char* const supported_prefixes[] = {"MyConv", "MyRelu", "Concat"};
  for (const char* prefix : supported_prefixes) {
    if (StartsWith(op.type(), prefix)) {
      return true;
    }
  }
  return false;
}
// Transform callback handed to OptimizeForBackend by the tests in this file.
// Returns a net holding a single "BigOpt" operator. Every external input of
// `net` is forwarded to both the new net and the op, except names starting
// with "W" or "b", which are dropped. Every external output of `net` is
// forwarded unchanged.
caffe2::NetDef Transform(const caffe2::NetDef& net) {
  caffe2::NetDef fused_net;
  auto* fused_op = fused_net.add_op();
  fused_op->set_type("BigOpt");

  for (const auto& input_name : net.external_input()) {
    // Weights and biases are absorbed: they do not appear on the fused op.
    if (StartsWith(input_name, "W") || StartsWith(input_name, "b")) {
      continue;
    }
    fused_net.add_external_input(input_name);
    fused_op->add_input(input_name);
  }

  for (const auto& output_name : net.external_output()) {
    fused_net.add_external_output(output_name);
    fused_op->add_output(output_name);
  }

  return fused_net;
}
} // namespace
// N0 -> MyConv -> N1
TEST(BackendCuttingTest, unit) {
  // A net made of a single supported MyConv op (N0, W0, b0 -> N1).
  caffe2::NetDef net;
  const char* const input_names[] = {"N0", "W0", "b0"};
  for (const char* name : input_names) {
    net.add_external_input(name);
  }
  net.add_external_output("N1");
  AddConv(&net, 0);

  const auto optimized =
      caffe2::opt::OptimizeForBackend(net, Supports, Transform);

  // One op is expected; Transform drops the "W"/"b" inputs, leaving a single
  // external input and a single external output.
  EXPECT_EQ(1, optimized.op_size());
  EXPECT_EQ(1, optimized.external_input_size());
  EXPECT_EQ(1, optimized.external_output_size());
}
// X -> CopyIn -> MyConv -> MyConv -> CopyOut -> Y
TEST(BackendCuttingTest, line) {
  caffe2::NetDef net;

  // Weights and biases are declared as external inputs to exercise the
  // weight absorption done in Transform.
  const char* const input_names[] = {"X", "W0", "W1", "b0", "b1"};
  for (const char* name : input_names) {
    net.add_external_input(name);
  }
  net.add_external_output("Y");

  // X -> CopyIn -> N0
  auto* copy_in = net.add_op();
  copy_in->set_type("CopyIn");
  copy_in->add_input("X");
  copy_in->add_output("N0");

  // N0 -> MyConv -> N1 -> MyConv -> N2
  AddConv(&net, 0);
  AddConv(&net, 1);

  // N2 -> CopyOut -> Y
  auto* copy_out = net.add_op();
  copy_out->set_type("CopyOut");
  copy_out->add_input("N2");
  copy_out->add_output("Y");

  const auto optimized =
      caffe2::opt::OptimizeForBackend(net, Supports, Transform);

  // Presumably CopyIn, one op replacing both MyConv ops, and CopyOut.
  EXPECT_EQ(3, optimized.op_size());
}
// X0 -> CopyIn -> MyConv -|
//                    X1 --> Concat -> CopyOut -> Y
// N2 -> MyConv -> MyRelu -|
TEST(BackendCuttingTest, convergedPaths) {
  // Two branches plus the external input X1 converge on one Concat op:
  //   X0 -> CopyIn -> N0 -> MyConv -> N1
  //   N2 -> MyConv -> N3 -> MyRelu -> N4
  //   Concat(X1, N1, N4) -> N5 -> CopyOut -> Y
  caffe2::NetDef net;
  net.add_external_input("X0");
  net.add_external_input("X1");
  net.add_external_input("N2");
  net.add_external_output("Y");

  auto* op = net.add_op();
  op->set_type("CopyIn");
  op->add_input("X0");
  op->add_output("N0");

  AddConv(&net, 0); // N0 -> N1
  AddConv(&net, 2); // N2 -> N3

  op = net.add_op();
  op->set_type("MyRelu");
  op->add_input("N3");
  op->add_output("N4");

  op = net.add_op();
  op->set_type("Concat");
  op->add_input("X1");
  op->add_input("N1");
  op->add_input("N4");
  op->add_output("N5");

  op = net.add_op();
  op->set_type("CopyOut");
  op->add_input("N5");
  op->add_output("Y");

  auto net_opt = caffe2::opt::OptimizeForBackend(net, Supports, Transform);

  // Presumably CopyIn, one op replacing the supported ops, and CopyOut.
  EXPECT_EQ(3, net_opt.op_size());
}
//              |-> Random -> MyRelu -> MyConv -|
//              |                               v
// N0 -> MyConv -> MyRelu -> MyConv --------> Concat -> CopyOut -> Y
TEST(BackendCuttingTest, skipPath) {
  caffe2::NetDef net;
  net.add_external_input("N0");
  net.add_external_output("Y");

  // Appends an op of `type` with exactly one input and one output blob.
  const auto add_unary_op =
      [&net](const char* type, const char* input, const char* output) {
        auto* unary = net.add_op();
        unary->set_type(type);
        unary->add_input(input);
        unary->add_output(output);
      };

  AddConv(&net, 0); // N0 -> N1
  add_unary_op("MyRelu", "N1", "N2");
  // "Random" is not matched by Supports and also consumes N1.
  add_unary_op("Random", "N1", "N4");
  add_unary_op("MyRelu", "N4", "N5");
  AddConv(&net, 2); // N2 -> N3
  AddConv(&net, 5); // N5 -> N6

  // Both paths meet again here.
  auto* concat = net.add_op();
  concat->set_type("Concat");
  concat->add_input("N3");
  concat->add_input("N6");
  concat->add_output("N7");

  add_unary_op("CopyOut", "N7", "Y");

  const auto optimized =
      caffe2::opt::OptimizeForBackend(net, Supports, Transform);
  EXPECT_EQ(4, optimized.op_size());
}