#include <gtest/gtest.h>

#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/opt/bound_shape_inferencer.h"
#include "caffe2/utils/proto_utils.h"

using namespace caffe2;

namespace {

using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
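// Builds a ShapeInfo with the given dim type, dims, and data type
// (FLOAT by default).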
ShapeInfo MakeTensorInfo(
    ShapeInfo::DimType t,
    const std::vector<int64_t>& dims,
    TensorProto::DataType dtype = TensorProto_DataType_FLOAT) {
  ShapeInfo info;
  info.dim_type = t;
  TensorShape& shape = info.shape;
  for (const auto d : dims) {
    shape.add_dims(d);
  }
  shape.set_data_type(dtype);
  return info;
}
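// Logs the dim type, dims, and data type of every entry in a ShapeInfoMap.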
void PrintShape(const ShapeInfoMap& map) {
  for (const auto& kv : map) {
    const auto& s = kv.second;
    std::stringstream ss;
    ss << s.shape.name() << ": dim_type: " << s.dim_type << ", dims: [";
    for (const auto d : s.shape.dims()) {
      ss << d << ", ";
    }
    ss << "], dtype: " << s.shape.data_type();
    LOG(INFO) << ss.str();
  }
}
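// Asserts that `name` is present in `info` and that its dim type, dims, and
// data type match the expected values.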
void VerifyShapeInfo(
    const ShapeInfoMap& info,
    const std::string& name,
    ShapeInfo::DimType t,
    const std::vector<int64_t>& dims,
    TensorProto::DataType dtype = TensorProto_DataType_FLOAT) {
  LOG(INFO) << "Checking " << name;
  const auto it = info.find(name);
  ASSERT_TRUE(it != info.end());
  const auto& shape_info = it->second;
  EXPECT_EQ(shape_info.dim_type, t);
  const auto& shape = shape_info.shape;
  ASSERT_EQ(shape.dims_size(), dims.size());
  for (int i = 0; i < dims.size(); ++i) {
    EXPECT_EQ(shape.dims(i), dims[i]);
  }
  EXPECT_EQ(shape.data_type(), dtype);
}
} // namespace
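// With only the constant "Weights" shape given, the inferencer should bound
// "Data" by max_seq_size and "Lengths"/"Out" by max_batch_size.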
TEST(BoundShapeInference, SparseLengthsSum) {
  NetDef net;
  net.add_op()->CopyFrom(CreateOperatorDef(
      "SparseLengthsSum", "", {"Weights", "Data", "Lengths"}, {"Out"}, {}));
  ShapeInfoMap shape_map;
  shape_map.emplace(
      "Weights", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1000}));
  BoundShapeSpec spec(20, 1000);
  BoundShapeInferencer eng(spec);
  eng.InferBoundShapeAndType(net, shape_map);
  const auto& out_shape = eng.shape_info();
  VerifyShapeInfo(
      out_shape, "Weights", ShapeInfo::DimType::CONSTANT, {16, 1000});
  VerifyShapeInfo(
      out_shape,
      "Data",
      ShapeInfo::DimType::SEQ,
      {spec.max_seq_size},
      TensorProto_DataType_INT32);
  VerifyShapeInfo(
      out_shape,
      "Lengths",
      ShapeInfo::DimType::BATCH,
      {spec.max_batch_size},
      TensorProto_DataType_INT32);
  VerifyShapeInfo(
      out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
}
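// For FC and FCTransposed, the input and output batch dimensions should be
// bound by max_batch_size, with the feature dimensions taken from the
// constant weight shapes.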
TEST(BoundShapeInference, FC) {
  NetDef net;
  net.add_op()->CopyFrom(
      CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
  net.add_op()->CopyFrom(
      CreateOperatorDef("FCTransposed", "", {"X1", "W1", "B1"}, {"Out1"}, {}));
  ShapeInfoMap shape_map;
  shape_map.emplace(
      "W0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
  shape_map.emplace("B0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
  shape_map.emplace(
      "W1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
  shape_map.emplace("B1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {1024}));
  BoundShapeSpec spec(20, 1000);
  BoundShapeInferencer eng(spec);
  eng.InferBoundShapeAndType(net, shape_map);
  const auto& out_shape = eng.shape_info();
  VerifyShapeInfo(
      out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
  VerifyShapeInfo(
      out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
  VerifyShapeInfo(
      out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
  VerifyShapeInfo(
      out_shape,
      "Out1",
      ShapeInfo::DimType::BATCH,
      {spec.max_batch_size, 1024});
}
// We don't support input shape inference when the weight is not 2D.
TEST(BoundShapeInference, UnsupportedFC) {
  NetDef net;
  net.add_op()->CopyFrom(
      CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
  ShapeInfoMap shape_map;
  shape_map.emplace(
      "W0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1, 1024}));
  shape_map.emplace("B0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
  BoundShapeSpec spec(20, 1000);
  BoundShapeInferencer eng(spec);
  EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet);
}
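// End-to-end check on a small net (SparseLengthsSum -> Concat -> BatchMatMul
// -> Flatten -> BatchGather): shapes should propagate through to "Gout".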
TEST(BoundShapeInference, Combo0) {
  NetDef net;
  net.add_op()->CopyFrom(CreateOperatorDef(
      "SparseLengthsSum", "", {"Weights0", "Data0", "Lengths0"}, {"EB0"}, {}));
  net.add_op()->CopyFrom(CreateOperatorDef(
      "SparseLengthsSum", "", {"Weights1", "Data1", "Lengths1"}, {"EB1"}, {}));
  net.add_op()->CopyFrom(CreateOperatorDef(
      "Concat",
      "",
      {"EB0", "EB1"},
      {"Cout", "split_info"},
      {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)}));
  net.add_op()->CopyFrom(CreateOperatorDef(
      "BatchMatMul",
      "",
      {"Cout", "Cout"},
      {"Bout"},
      {MakeArgument<int>("trans_b", 1)}));
  net.add_op()->CopyFrom(
      CreateOperatorDef("Flatten", "", {"Bout"}, {"Fout"}, {}));
  net.add_op()->CopyFrom(
      CreateOperatorDef("BatchGather", "", {"Fout", "Indices"}, {"Gout"}, {}));
  ShapeInfoMap shape_map;
  shape_map.emplace(
      "Weights0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1000}));
  shape_map.emplace(
      "Weights1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 20000}));
  shape_map.emplace(
      "Indices", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {2}));
  BoundShapeSpec spec(20, 1000);
  BoundShapeInferencer eng(spec);
  eng.InferBoundShapeAndType(net, shape_map);
  const auto& out_shape = eng.shape_info();
  PrintShape(out_shape);
  VerifyShapeInfo(
      out_shape, "Gout", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2});
}