blob: ec873651166b5bda51aad97762edd0a7e0fbf3b9 [file] [log] [blame]
#include "caffe2/opt/converter.h"
#include <gtest/gtest.h>
// Appends a new argument to the operator that `_op` points to, names it
// `_name`, and stores `_val` through the setter `set_<_type>` (the call sites
// in this file use `i` -> set_i and `s` -> set_s).
//
// Macro hygiene notes:
//  * The body is wrapped in do { ... } while (0) so that `ADD_ARG(...);` is
//    exactly one statement and stays valid inside an unbraced if/else.
//  * `_op` is parenthesized so any pointer-valued expression may be passed.
//  * The temporary carries a macro-specific name so it cannot shadow an
//    identifier that appears in `_name` or `_val`.
#define ADD_ARG(_op, _name, _type, _val)            \
  do {                                              \
    auto* add_arg_tmp_ = (_op)->add_arg();          \
    add_arg_tmp_->set_name(_name);                  \
    add_arg_tmp_->set_##_type(_val);                \
  } while (0)
// Builds a chain of ten operators, each picked by rand() to be either a Conv
// or a Relu, where every operator reads and writes the blob "X". The net is
// converted to an NNModule and back. The test contains no assertions, so it
// passes as long as both conversions run to completion.
TEST(Converter, Basic) {
  caffe2::NetDef net;
  for (int idx = 0; idx < 10; ++idx) {
    const bool emit_conv = (rand() % 2) != 0;
    caffe2::OperatorDef* op = net.add_op();

    if (!emit_conv) {
      op->set_type("Relu");
      op->add_input("X");
      op->add_output("X");
      op->mutable_device_option()->set_node_name("relu_runner");
      continue;
    }

    op->set_type("Conv");
    op->add_input("X");
    // Each Conv gets its own weight blob, named after the loop index.
    op->add_input("W" + c10::to_string(idx));
    ADD_ARG(op, "kernel", i, 3);
    ADD_ARG(op, "stride", i, 1);
    ADD_ARG(op, "pad", i, 0);
    ADD_ARG(op, "order", s, "NCHW");
    op->add_output("X");
    op->mutable_device_option()->set_node_name("conv_runner");
  }

  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);
}
// Round-trips a net holding a single operator of type "NeverSeen" through
// the NNModule representation. The operator's device option node name is
// "device_0" or "device_1", chosen with rand(). There are no assertions; the
// test passes when both conversions complete.
TEST(Converter, UnknownType) {
  caffe2::NetDef net;

  caffe2::OperatorDef* op = net.add_op();
  op->set_type("NeverSeen");
  op->add_input("X");
  op->add_output("X");

  const auto node_name = "device_" + c10::to_string(rand() % 2);
  op->mutable_device_option()->set_node_name(node_name);

  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);
}
// Returns a three-operator net used as a fixture by the tests below:
//
//   Fake(X)    -> Y
//   Fake(Y)    -> Z
//   Fake(Z, X) -> W
//
// "X" is registered as the only external input; "Y" and "W" are registered
// as external outputs (in that order).
caffe2::NetDef fakeNet() {
  caffe2::NetDef net;

  // Appends a new operator of type "Fake" and hands it back for wiring.
  auto appendFakeOp = [&net]() {
    caffe2::OperatorDef* op = net.add_op();
    op->set_type("Fake");
    return op;
  };

  caffe2::OperatorDef* x_to_y = appendFakeOp();
  x_to_y->add_input("X");
  x_to_y->add_output("Y");

  caffe2::OperatorDef* y_to_z = appendFakeOp();
  y_to_z->add_input("Y");
  y_to_z->add_output("Z");

  caffe2::OperatorDef* zx_to_w = appendFakeOp();
  zx_to_w->add_input("Z");
  zx_to_w->add_input("X");
  zx_to_w->add_output("W");

  net.add_external_input("X");
  net.add_external_output("Y");
  net.add_external_output("W");
  return net;
}
// Checks that the external inputs of fakeNet() survive a round trip through
// the NNModule representation: same count, same names, same order.
TEST(Converter, ExternalInputs) {
  auto net = fakeNet();
  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);

  // Fatal on mismatch: the loop below indexes new_netdef with indices taken
  // from net, so continuing with differing sizes would read out of range.
  ASSERT_EQ(new_netdef.external_input().size(), net.external_input().size());
  for (int i = 0; i < net.external_input().size(); ++i) {
    EXPECT_EQ(new_netdef.external_input(i), net.external_input(i));
  }
}
// Checks that the external outputs of fakeNet() survive a round trip through
// the NNModule representation: same count, same names, same order.
TEST(Converter, ExternalOutputs) {
  auto net = fakeNet();
  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);

  // Fatal on mismatch: the loop below indexes new_netdef with indices taken
  // from net, so continuing with differing sizes would read out of range.
  ASSERT_EQ(new_netdef.external_output().size(), net.external_output().size());
  for (int i = 0; i < net.external_output().size(); ++i) {
    EXPECT_EQ(new_netdef.external_output(i), net.external_output(i));
  }
}
// Exercises injectDataEdgeIndicators / removeDataEdgeIndicators on fakeNet():
//  1. After injection the net is expected to hold its 3 original ops plus one
//     "Declare" and two "Export" ops, and to list no external inputs or
//     outputs.
//  2. After a round trip through the NNModule representation followed by
//     removal, no "Declare"/"Export" ops may remain and the net is expected to
//     list 1 external input and 2 external outputs again.
TEST(Converter, InjectDataEdgeIndicators) {
  auto net = fakeNet();
  caffe2::injectDataEdgeIndicators(&net);

  // 3 original ops + 1 inserted Declare + 2 inserted Export.
  EXPECT_EQ(net.op_size(), 3 + 1 + 2);

  int declare_count = 0;
  int export_count = 0;
  for (const auto& op : net.op()) {
    if (op.type() == "Declare") {
      ++declare_count;
    } else if (op.type() == "Export") {
      ++export_count;
    }
  }
  EXPECT_EQ(declare_count, 1);
  EXPECT_EQ(export_count, 2);

  // The external input/output lists are expected to be empty after injection.
  EXPECT_EQ(net.external_input_size(), 0);
  EXPECT_EQ(net.external_output_size(), 0);

  // Make sure nomnigraph can round-trip the net with the indicator ops in it.
  auto nn = caffe2::convertToNNModule(net);
  auto new_net = caffe2::convertToCaffe2Proto(nn);

  caffe2::removeDataEdgeIndicators(&new_net);
  for (const auto& op : new_net.op()) {
    EXPECT_NE(op.type(), "Declare");
    EXPECT_NE(op.type(), "Export");
  }
  EXPECT_EQ(new_net.external_input_size(), 1);
  EXPECT_EQ(new_net.external_output_size(), 2);
}