// blob: 7089269543a0e2b29f271ed5099e6cc685a72533 [file] [log] [blame]
#include "caffe2/core/net.h"
#include "caffe2/core/operator.h"
#include "caffe2/mpi/mpi_common.h"
#include "glog/logging.h"
#include "gflags/gflags.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
namespace caffe2 {
// Text-format NetDef used by TestBroadcast.
//
// The test patches this definition by position after parsing it:
//   op(0).arg(1) must be the "value" argument, and
//   op(1).arg(0) must be the "root" argument,
// so the order of the ops and of their args below must not change.
const char kBcastNet[] =
    " name: \"bcast\""
    // op(0): ConstantFill producing blob "X" (args: shape, value).
    " op {"
    " output: \"X\" type: \"ConstantFill\""
    " arg { name: \"shape\" ints: 10 }"
    " arg { name: \"value\" f: 0.0 }"
    " }"
    // op(1): Broadcast reading and writing blob "X" (arg: root).
    " op {"
    " input: \"X\" output: \"X\" type: \"Broadcast\""
    " arg { name: \"root\" i: 0 }"
    " }";
// Runs the kBcastNet network once per possible root rank and checks the
// contents of blob "X" afterwards.
//
// Every process sets the ConstantFill "value" argument to its own MPI rank,
// so before the Broadcast op each process holds its own rank in X. After the
// net has run with a given root, the test expects all 10 elements of X to
// equal that root on every process.
TEST(MPITest, TestBroadcast) {
  NetDef net_def;
  CHECK(google::protobuf::TextFormat::ParseFromString(
      string(kBcastNet), &net_def));

  // Set the network's constant fill value to be the MPI rank.
  auto* arg = net_def.mutable_op(0)->mutable_arg(1);
  CHECK_EQ(arg->name(), "value");
  int rank = 0;
  ASSERT_EQ(MPI_SUCCESS, MPI_Comm_rank(MPI_COMM_WORLD, &rank));
  arg->set_f(rank);

  int size = 0;
  ASSERT_EQ(MPI_SUCCESS, MPI_Comm_size(MPI_COMM_WORLD, &size));

  for (int root = 0; root < size; ++root) {
    // Point the Broadcast op at the root under test.
    net_def.mutable_op(1)->mutable_arg(0)->set_i(root);

    // A fresh workspace per root keeps the iterations independent.
    Workspace ws;
    unique_ptr<NetBase> net(CreateNet(net_def, &ws));
    // Fatal on purpose: the following lines dereference `net`.
    ASSERT_NE(nullptr, net.get());
    EXPECT_TRUE(net->Verify());
    EXPECT_TRUE(net->Run());

    // Check the broadcast result. The blob pointer is checked before use so
    // that a missing blob is reported as a test failure instead of a crash.
    const auto* x_blob = ws.GetBlob("X");
    ASSERT_NE(nullptr, x_blob);
    auto& X = x_blob->Get<Tensor<float, CPUContext> >();
    EXPECT_EQ(X.size(), 10);
    for (int i = 0; i < X.size(); ++i) {
      EXPECT_EQ(X.data()[i], root);
    }
  }
}
// Text-format NetDef used by TestAllreduce.
//
// The test patches op(0).arg(1) by position after parsing and requires it to
// be the "value" argument, so the order of the args below must not change.
const char kAllreduceNet[] =
    " name: \"allreduce\""
    // op(0): ConstantFill producing blob "X" (args: shape, value).
    " op {"
    " output: \"X\" type: \"ConstantFill\""
    " arg { name: \"shape\" ints: 10 }"
    " arg { name: \"value\" f: 0.0 }"
    " }"
    // op(1): Allreduce reading blob "X" and writing blob "X_reduced".
    " op {"
    " input: \"X\" output: \"X_reduced\" type: \"Allreduce\""
    " }";
// Runs the kAllreduceNet network and checks both its input and output blobs.
//
// Every process sets the ConstantFill "value" argument to its own MPI rank.
// After the net has run, the test expects:
//   - "X" to still hold the local rank in all 10 elements, and
//   - "X_reduced" to hold 0 + 1 + ... + (size - 1) in all 10 elements,
//     i.e. the sum of the ranks of all participating processes.
TEST(MPITest, TestAllreduce) {
  NetDef net_def;
  CHECK(google::protobuf::TextFormat::ParseFromString(
      string(kAllreduceNet), &net_def));

  // Set the network's constant fill value to be the MPI rank.
  auto* arg = net_def.mutable_op(0)->mutable_arg(1);
  CHECK_EQ(arg->name(), "value");
  int rank = 0;
  ASSERT_EQ(MPI_SUCCESS, MPI_Comm_rank(MPI_COMM_WORLD, &rank));
  arg->set_f(rank);

  int size = 0;
  ASSERT_EQ(MPI_SUCCESS, MPI_Comm_size(MPI_COMM_WORLD, &size));

  Workspace ws;
  unique_ptr<NetBase> net(CreateNet(net_def, &ws));
  // Fatal on purpose: the following lines dereference `net`.
  ASSERT_NE(nullptr, net.get());
  EXPECT_TRUE(net->Verify());
  EXPECT_TRUE(net->Run());

  // The input blob is expected to be left holding the local rank. The blob
  // pointer is checked before use so that a missing blob is reported as a
  // test failure instead of a crash.
  const auto* x_blob = ws.GetBlob("X");
  ASSERT_NE(nullptr, x_blob);
  auto& X = x_blob->Get<Tensor<float, CPUContext> >();
  EXPECT_EQ(X.size(), 10);
  for (int i = 0; i < X.size(); ++i) {
    EXPECT_EQ(X.data()[i], rank);
  }

  // The output blob is expected to hold the sum of all ranks.
  const auto* reduced_blob = ws.GetBlob("X_reduced");
  ASSERT_NE(nullptr, reduced_blob);
  auto& X_reduced = reduced_blob->Get<Tensor<float, CPUContext> >();
  EXPECT_EQ(X_reduced.size(), 10);
  // Sum of 0..size-1.
  int expected_result = size * (size - 1) / 2;
  for (int i = 0; i < X_reduced.size(); ++i) {
    EXPECT_EQ(X_reduced.data()[i], expected_result);
  }
}
} // namespace caffe2
// Command-line flag --caffe_test_root (default "gen/").
// NOTE(review): nothing in this file reads FLAGS_caffe_test_root; presumably
// it is defined here so this binary accepts the same flag as other test
// binaries -- confirm before removing.
DEFINE_string(
    caffe_test_root,
    "gen/",
    "The root of the caffe test folder.");
// Test entry point: initializes MPI, GoogleTest, gflags and glog (in that
// order), runs all registered tests, finalizes MPI, and returns the
// RUN_ALL_TESTS() result as the process exit code.
GTEST_API_ int main(int argc, char** argv) {
  // The last argument of MPI_Init_thread receives the thread-support level
  // the MPI library actually granted; it is not an error code.
  int provided_thread_level = MPI_THREAD_SINGLE;
  const int init_status = MPI_Init_thread(
      &argc, &argv, MPI_THREAD_MULTIPLE, &provided_thread_level);
  if (init_status != MPI_SUCCESS) {
    // glog is not initialized yet, so this message goes to stderr.
    LOG(ERROR) << "MPI_Init_thread failed with error code " << init_status;
    return 1;
  }

  testing::InitGoogleTest(&argc, argv);
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  google::InitGoogleLogging(argv[0]);

  // Only warn (rather than abort) when less than MPI_THREAD_MULTIPLE was
  // granted, so the tests still run on MPI builds without full thread support.
  LOG_IF(WARNING, provided_thread_level < MPI_THREAD_MULTIPLE)
      << "Requested MPI_THREAD_MULTIPLE but the MPI library provided "
      << "thread level " << provided_thread_level;

  int test_result = RUN_ALL_TESTS();
  MPI_Finalize();
  return test_result;
}