// blob: 2996fe40bb1f573e92ea8e3d78d6e4797d0fcbcc
// Unit tests for ProfilingInfo.
#include "caffe2/contrib/prof/profiling_info.h"
#include <google/protobuf/util/message_differencer.h>
#include <gtest/gtest.h>
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
namespace contrib {
namespace prof {
namespace {
// Text-format ProfDAGProtos fixture for net "example_net". It carries two
// `stats` entries: one named "op1" with output profiles for "var1" and
// "var12", and one with an empty name with an output profile for "var2".
// The pointer itself is const so the fixture cannot be reseated by a test.
const char* const kTestProfile = R"(
stats {
name: "op1"
mean: 5
stddev: 7
execution_time {
mean: 5
stddev: 7
count: 11
}
output_profile {
name: "var1"
bytes_used {
mean: 13
stddev: 17
count: 10
}
}
output_profile {
name: "var12"
bytes_used {
mean: 43
stddev: 47
count: 53
}
}
}
stats {
name: ""
mean: 19
stddev: 23
execution_time {
mean: 19
stddev: 23
count: 29
}
output_profile {
name: "var2"
bytes_used {
mean: 31
stddev: 37
count: 41
}
}
}
net_name: "example_net"
)";
// Text-format NetDef fixture whose ops and outputs line up with kTestProfile:
// op "op1" (type "add") outputs "var1" and "var12"; the unnamed op (type
// "mult") outputs "var2". The pointer is const so tests cannot reseat it.
const char* const kTestNetDefCorrect = R"(
name: "example_net"
op {
name: "op1"
type: "add"
output: "var1"
output: "var12"
}
op {
name: ""
type: "mult"
output: "var2"
}
)";
// Text-format NetDef fixture that only partially matches kTestProfile: the
// first op agrees on "op1"/"var1", the second op lists an output name
// ("var_NET_WRONG") that the profile does not contain, and the third op
// ("op_NET_WRONG") has no counterpart in the profile at all.
// The pointer is const so tests cannot reseat it.
const char* const kTestNetDefPartial = R"(
name: "example_net"
op {
name: "op1"
output: "var1"
}
op {
name: ""
output: "var_NET_WRONG"
}
op {
name: "op_NET_WRONG"
}
)";
// Restoring from a profile that matches the net is expected to succeed and to
// yield three blob entries and two operator entries carrying the profiled
// statistics.
TEST(ProfilingInfoTest, CorrectParse) {
  ProfDAGProtos profile;
  ASSERT_TRUE(TextFormat::ParseFromString(string(kTestProfile), &profile));
  NetDef net_def;
  ASSERT_TRUE(
      TextFormat::ParseFromString(string(kTestNetDefCorrect), &net_def));

  ProfilingInfo info;
  EXPECT_TRUE(info.Restore(net_def, profile));

  const auto& blobs = info.getBlobMap();
  EXPECT_EQ(3, blobs.size());
  const auto& ops = info.getOperatorMap();
  EXPECT_EQ(2, ops.size());

  // Blob statistics are keyed by output name.
  const auto var1 = blobs.find("var1");
  ASSERT_NE(blobs.end(), var1);
  EXPECT_FLOAT_EQ(13, var1->second.getUsedBytes().getMean());

  const auto var12 = blobs.find("var12");
  ASSERT_NE(blobs.end(), var12);
  EXPECT_FLOAT_EQ(47, var12->second.getUsedBytes().getStddev());

  const auto var2 = blobs.find("var2");
  ASSERT_NE(blobs.end(), var2);
  EXPECT_FLOAT_EQ(31, var2->second.getUsedBytes().getMean());

  // Operator statistics are keyed by the operator's index.
  const auto first_op = ops.find(0);
  ASSERT_NE(ops.end(), first_op);
  EXPECT_FLOAT_EQ(5, first_op->second.getExecutionTimeMs().getMean());

  const auto second_op = ops.find(1);
  ASSERT_NE(ops.end(), second_op);
  EXPECT_FLOAT_EQ(23, second_op->second.getExecutionTimeMs().getStddev());
}
// Restoring against a net that only partially matches the profile is expected
// to report failure, while still leaving one blob entry ("var1") and two
// operator entries populated.
TEST(ProfilingInfoTest, PartialParse) {
  ProfDAGProtos profile;
  ASSERT_TRUE(TextFormat::ParseFromString(string(kTestProfile), &profile));
  NetDef net_def;
  ASSERT_TRUE(
      TextFormat::ParseFromString(string(kTestNetDefPartial), &net_def));

  ProfilingInfo info;
  EXPECT_FALSE(info.Restore(net_def, profile));

  const auto& blobs = info.getBlobMap();
  EXPECT_EQ(1, blobs.size());
  const auto& ops = info.getOperatorMap();
  EXPECT_EQ(2, ops.size());

  // The one output name shared by net and profile keeps its statistics.
  const auto var1 = blobs.find("var1");
  ASSERT_NE(blobs.end(), var1);
  EXPECT_FLOAT_EQ(13, var1->second.getUsedBytes().getMean());

  // Both profiled operators keep their execution-time statistics.
  const auto first_op = ops.find(0);
  ASSERT_NE(ops.end(), first_op);
  EXPECT_FLOAT_EQ(5, first_op->second.getExecutionTimeMs().getMean());

  const auto second_op = ops.find(1);
  ASSERT_NE(ops.end(), second_op);
  EXPECT_FLOAT_EQ(23, second_op->second.getExecutionTimeMs().getStddev());
}
// Init() on a net is expected to create one entry per output blob and one per
// operator; the mutable accessors can then be used to record data points.
TEST(ProfilingInfoTest, InitAndAddStats) {
  NetDef net_def;
  ASSERT_TRUE(
      TextFormat::ParseFromString(string(kTestNetDefCorrect), &net_def));
  ProfilingInfo info;
  info.Init(net_def);

  // The maps must hold the right number of entries.
  EXPECT_EQ(3, info.getBlobMap().size());
  EXPECT_EQ(2, info.getOperatorMap().size());

  // The maps must hold the right keys. Start with a blob via the mutable map.
  auto blobs_rw = info.getMutableBlobMap();
  auto var1 = blobs_rw->find("var1");
  ASSERT_NE(blobs_rw->end(), var1);
  // Sample profiling call: record a used-bytes data point.
  var1->second.getMutableUsedBytes()->addPoint(3);

  // The remaining blobs are only looked up through the read-only map.
  const auto& blobs_ro = info.getBlobMap();
  EXPECT_NE(blobs_ro.end(), blobs_ro.find("var12"));
  EXPECT_NE(blobs_ro.end(), blobs_ro.find("var2"));

  auto ops_rw = info.getMutableOperatorMap();
  auto first_op = ops_rw->find(0);
  ASSERT_NE(ops_rw->end(), first_op);
  // Sample profiling call: record an execution-time data point.
  first_op->second.getMutableExecutionTimeMs()->addPoint(3);

  const auto& ops_ro = info.getOperatorMap();
  EXPECT_NE(ops_ro.end(), ops_ro.find(1));
}
// Text-format ProfDAGProtos that InitAddStatsAndGetOperatorStatsProto expects
// GetOperatorAndDataStats() to produce (non-old format) after a single
// execution-time point of 5 is added to operator 0 and a single used-bytes
// point of 3 is added to "var1"; every other statistic stays at zero.
// The pointer is const so tests cannot reseat it.
const char* const kExpectedProto = R"(
stats {
name: "op1"
mean: 5
stddev: 0
execution_time {
mean: 5
stddev: 0
count: 1
}
output_profile {
name: "var1"
bytes_used {
mean: 3
stddev: 0
count: 1
}
}
output_profile {
name: "var12"
bytes_used {
mean: 0
stddev: 0
count: 0
}
}
}
stats {
name: ""
mean: 0
stddev: 0
execution_time {
mean: 0
stddev: 0
count: 0
}
output_profile {
name: "var2"
bytes_used {
mean: 0
stddev: 0
count: 0
}
}
}
net_name: "example_net"
)";
// Initializes from a net, records one data point on operator 0 and on blob
// "var1", then checks that the exported proto (non-old format) approximately
// equals kExpectedProto.
TEST(ProfilingInfoTest, InitAddStatsAndGetOperatorStatsProto) {
  NetDef net_def;
  ProfilingInfo info;
  ASSERT_TRUE(
      TextFormat::ParseFromString(string(kTestNetDefCorrect), &net_def));
  ProfDAGProtos expected;
  ASSERT_TRUE(TextFormat::ParseFromString(string(kExpectedProto), &expected));
  info.Init(net_def);
  // Add stats. Each lookup is verified before the iterator is dereferenced so
  // that a missing entry fails the test instead of invoking undefined
  // behavior.
  auto op_it = info.getMutableOperatorMap()->find(0);
  ASSERT_NE(info.getMutableOperatorMap()->end(), op_it);
  op_it->second.getMutableExecutionTimeMs()->addPoint(5);
  auto blob_it = info.getMutableBlobMap()->find("var1");
  ASSERT_NE(info.getMutableBlobMap()->end(), blob_it);
  blob_it->second.getMutableUsedBytes()->addPoint(3);
  // Export to proto.
  ProfDAGProtos generated;
  ASSERT_TRUE(info.GetOperatorAndDataStats(net_def, false, &generated));
  EXPECT_TRUE(google::protobuf::util::MessageDifferencer::ApproximatelyEquals(
      expected, generated))
      << generated.DebugString();
}
// Round trip: a profile restored into ProfilingInfo and exported again in the
// non-old format is expected to approximately equal the original profile.
TEST(ProfilingInfoTest, NewFormatSymmetry) {
  NetDef net_def;
  ASSERT_TRUE(
      TextFormat::ParseFromString(string(kTestNetDefCorrect), &net_def));
  ProfDAGProtos profile;
  ASSERT_TRUE(TextFormat::ParseFromString(string(kTestProfile), &profile));
  ProfilingInfo info;
  // The round trip is only meaningful if the restore itself succeeded, so a
  // failure here must stop the test rather than be silently ignored.
  ASSERT_TRUE(info.Restore(net_def, profile));
  ProfDAGProtos generated;
  ASSERT_TRUE(info.GetOperatorAndDataStats(net_def, false, &generated));
  EXPECT_TRUE(google::protobuf::util::MessageDifferencer::ApproximatelyEquals(
      profile, generated))
      << generated.DebugString();
}
// Text-format ProfDAGProtos that the old-format export test expects: one
// `stats` entry per operator with a name built from the net name, the
// operator index and the operator type, and only mean/stddev populated.
// The pointer is const so tests cannot reseat it.
const char* const kExpectedProtoOldFormat = R"(
stats {
name: "example_net___0___add"
mean: 5
stddev: 0
}
stats {
name: "example_net___1___mult"
mean: 0
stddev: 0
}
)";
// Same setup as the non-old-format export test, but exports with the
// old-format flag set and compares against kExpectedProtoOldFormat using
// approximate equivalence.
TEST(ProfilingInfoTest, InitAddStatsAndGetOperatorStatsProtoOldFormat) {
  NetDef net_def;
  ProfilingInfo info;
  ASSERT_TRUE(
      TextFormat::ParseFromString(string(kTestNetDefCorrect), &net_def));
  ProfDAGProtos expected;
  ASSERT_TRUE(
      TextFormat::ParseFromString(string(kExpectedProtoOldFormat), &expected));
  info.Init(net_def);
  // Add stats. Each lookup is verified before the iterator is dereferenced so
  // that a missing entry fails the test instead of invoking undefined
  // behavior.
  auto op_it = info.getMutableOperatorMap()->find(0);
  ASSERT_NE(info.getMutableOperatorMap()->end(), op_it);
  op_it->second.getMutableExecutionTimeMs()->addPoint(5);
  auto blob_it = info.getMutableBlobMap()->find("var1");
  ASSERT_NE(info.getMutableBlobMap()->end(), blob_it);
  blob_it->second.getMutableUsedBytes()->addPoint(3);
  // Export to proto.
  ProfDAGProtos generated;
  ASSERT_TRUE(info.GetOperatorAndDataStats(net_def, true, &generated));
  EXPECT_TRUE(
      google::protobuf::util::MessageDifferencer::ApproximatelyEquivalent(
          expected, generated))
      << generated.DebugString();
}
// Text-format NetDef fixture with three ops spanning two types: "op1" and
// "op2" are of type "add", "op3" is of type "mult". Used to exercise the
// per-type statistics export. The pointer is const so tests cannot reseat it.
const char* const kTestNetDefMultiOpMultiType = R"(
name: "example_net"
op {
name: "op1"
type: "add"
output: "var1"
output: "var12"
}
op {
name: "op2"
type: "add"
output: "var12"
}
op {
name: "op3"
type: "mult"
output: "var2"
}
)";
// Text-format ProfDAGProtos that the per-type export test expects: one
// `stats` entry per operator type. The test feeds execution-time points 5 and
// 7 to the two "add" ops and 11 to the "mult" op.
// The pointer is const so tests cannot reseat it.
const char* const kExpectedProtoByType = R"(
stats {
name: "add"
mean: 6
stddev: 1
}
stats {
name: "mult"
mean: 11
stddev: 0
}
)";
// Records one execution-time point per operator on a three-op, two-type net
// and checks that GetOperatorTypeStats() exports the statistics aggregated by
// operator type, accepting either ordering of the two types.
TEST(ProfilingInfoTest, InitAddStatsAndGetOperatorTypeStatsProto) {
  NetDef net_def;
  ProfilingInfo info;
  ASSERT_TRUE(TextFormat::ParseFromString(
      string(kTestNetDefMultiOpMultiType), &net_def));
  ProfDAGProtos expected;
  ASSERT_TRUE(
      TextFormat::ParseFromString(string(kExpectedProtoByType), &expected));
  info.Init(net_def);
  // Add stats. Each lookup is verified before the iterator is dereferenced so
  // that a missing entry fails the test instead of invoking undefined
  // behavior.
  auto op_it = info.getMutableOperatorMap()->find(0);
  ASSERT_NE(info.getMutableOperatorMap()->end(), op_it);
  op_it->second.getMutableExecutionTimeMs()->addPoint(5);
  op_it = info.getMutableOperatorMap()->find(1);
  ASSERT_NE(info.getMutableOperatorMap()->end(), op_it);
  op_it->second.getMutableExecutionTimeMs()->addPoint(7);
  op_it = info.getMutableOperatorMap()->find(2);
  ASSERT_NE(info.getMutableOperatorMap()->end(), op_it);
  op_it->second.getMutableExecutionTimeMs()->addPoint(11);
  // Export to proto.
  ProfDAGProtos generated;
  ASSERT_TRUE(info.GetOperatorTypeStats(net_def, &generated));
  // The swap below indexes entries 0 and 1, so make sure both exist before
  // touching them.
  ASSERT_EQ(2, generated.stats_size()) << generated.DebugString();
  // Since op types are unordered, it is OK to match either.
  ProfDAGProtos generated_swapped = generated;
  std::swap(
      *generated_swapped.mutable_stats(0), *generated_swapped.mutable_stats(1));
  EXPECT_TRUE(
      google::protobuf::util::MessageDifferencer::ApproximatelyEquivalent(
          expected, generated) ||
      google::protobuf::util::MessageDifferencer::ApproximatelyEquivalent(
          expected, generated_swapped))
      << generated.DebugString();
}
} // namespace
} // namespace prof
} // namespace contrib
} // namespace caffe2