|  | #include <caffe2/core/common.h> | 
|  | #include <caffe2/core/test_utils.h> | 
|  | #include <caffe2/core/workspace.h> | 
|  | #include <caffe2/opt/onnxifi_transformer.h> | 
|  | #include <caffe2/utils/proto_utils.h> | 
|  |  | 
|  | #include <gtest/gtest.h> | 
|  |  | 
|  | using namespace caffe2::testing; | 
|  | using namespace caffe2; | 
|  |  | 
|  | namespace { | 
|  | NetDef createTest( | 
|  | const std::string& op_type, | 
|  | Workspace* ws, | 
|  | bool has_weight, | 
|  | bool fallback) { | 
|  | NetDef net; | 
|  | std::vector<std::string> inputs{ | 
|  | "Data", "Weight", "Idx", "Lengths", "Compressed"}; | 
|  | if (!has_weight) { | 
|  | inputs = {"Data", "Idx", "Lengths", "Compressed"}; | 
|  | } | 
|  | NetMutator(&net).newOp(op_type, inputs, {"Out"}); | 
|  | auto* b = ws->CreateBlob("Compressed"); | 
|  | auto* t = BlobGetMutableTensor(b, {1}, at::dtype<int32_t>()); | 
|  | auto* comp = t->template mutable_data<int32_t>(); | 
|  | *comp = fallback ? 0 : 1; | 
|  | return net; | 
|  | } | 
|  |  | 
|  | void check( | 
|  | const NetDef& net, | 
|  | const std::string& op_type, | 
|  | bool has_weight, | 
|  | bool fallback) { | 
|  | const static std::unordered_map<string, string> slss = { | 
|  | {"SparseLengthsSum4BitRowwiseSparse", "SparseLengthsSumFused4BitRowwise"}, | 
|  | {"SparseLengthsWeightedSum4BitRowwiseSparse", | 
|  | "SparseLengthsWeightedSumFused4BitRowwise"}, | 
|  | {"SparseLengthsSum8BitRowwiseSparse", "SparseLengthsSumFused8BitRowwise"}, | 
|  | {"SparseLengthsWeightedSum8BitRowwiseSparse", | 
|  | "SparseLengthsWeightedSumFused8BitRowwise"}, | 
|  | {"SparseLengthsSum2BitRowwiseSparse", "SparseLengthsSumFused2BitRowwise"}, | 
|  | {"SparseLengthsWeightedSum2BitRowwiseSparse", | 
|  | "SparseLengthsWeightedSumFused2BitRowwise"}}; | 
|  | if (fallback) { | 
|  | EXPECT_EQ(net.op_size(), 1); | 
|  | EXPECT_EQ(net.op(0).type(), slss.at(op_type)); | 
|  | EXPECT_EQ(net.op(0).input_size(), has_weight ? 4 : 3); | 
|  | EXPECT_EQ(net.op(0).output_size(), 1); | 
|  | EXPECT_EQ(net.op(0).input(0), "Data"); | 
|  | EXPECT_EQ(net.op(0).input(has_weight ? 2 : 1), "Idx"); | 
|  | EXPECT_EQ(net.op(0).input(has_weight ? 3 : 2), "Lengths"); | 
|  | if (has_weight) { | 
|  | EXPECT_EQ(net.op(0).input(1), "Weight"); | 
|  | } | 
|  | EXPECT_EQ(net.op(0).output(0), "Out"); | 
|  | } else { | 
|  | EXPECT_EQ(net.op_size(), 2); | 
|  | EXPECT_EQ(net.op(0).type(), "SparseLengthsSumSparseLookup"); | 
|  | EXPECT_EQ(net.op(0).input_size(), has_weight ? 4 : 3); | 
|  | EXPECT_EQ(net.op(0).output_size(), has_weight ? 3 : 2); | 
|  | EXPECT_EQ(net.op(0).input(0), "Idx"); | 
|  | EXPECT_EQ(net.op(0).input(1), "Lengths"); | 
|  | EXPECT_EQ(net.op(0).input(2), "Compressed"); | 
|  | EXPECT_EQ(net.op(0).output(0), "Idx_decomp"); | 
|  | EXPECT_EQ(net.op(0).output(1), "Lengths_decomp"); | 
|  | if (has_weight) { | 
|  | EXPECT_EQ(net.op(0).input(3), "Weight"); | 
|  | EXPECT_EQ(net.op(0).output(2), "Weight_decomp"); | 
|  | } | 
|  | EXPECT_EQ(net.op(1).type(), slss.at(op_type)); | 
|  | EXPECT_EQ(net.op(1).input_size(), has_weight ? 4 : 3); | 
|  | EXPECT_EQ(net.op(1).output_size(), 1); | 
|  | EXPECT_EQ(net.op(1).input(0), "Data"); | 
|  | EXPECT_EQ(net.op(1).input(has_weight ? 2 : 1), "Idx_decomp"); | 
|  | EXPECT_EQ(net.op(1).input(has_weight ? 3 : 2), "Lengths_decomp"); | 
|  | if (has_weight) { | 
|  | EXPECT_EQ(net.op(1).input(1), "Weight_decomp"); | 
|  | } | 
|  | EXPECT_EQ(net.op(1).output(0), "Out"); | 
|  | } | 
|  | } | 
|  | } // namespace | 
|  |  | 
|  | TEST(splitSparseLengthsSumSparse, sweep) { | 
|  | std::vector<bool> has_weights = {true, false}; | 
|  | std::vector<bool> fallbacks = {true, false}; | 
|  | std::vector<int> bits = {2, 4, 8}; | 
|  | for (const auto has_weight : has_weights) { | 
|  | for (const auto bit : bits) { | 
|  | std::string op_type = "SparseLengths"; | 
|  | op_type += (has_weight ? "WeightedSum" : "Sum"); | 
|  | op_type += caffe2::to_string(bit); | 
|  | op_type += "BitRowwiseSparse"; | 
|  | for (const auto fallback : fallbacks) { | 
|  | Workspace ws; | 
|  | auto net = createTest(op_type, &ws, has_weight, fallback); | 
|  | splitSparseLengthsSumSparse(&net, ws); | 
|  | check(net, op_type, has_weight, fallback); | 
|  | } | 
|  | } | 
|  | } | 
|  | } |