Transform pass to split SparseLengthsSumSparse (#35522)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35522
We will need to apply this transform pass to the net before lowering it to Glow.
Test Plan:
```
buck test caffe2/caffe2/opt/custom:split_slss_test
```
Reviewed By: ipiszy
Differential Revision: D20688451
fbshipit-source-id: 22c0f5d0dcf97cc51cdc86bfc0abd90328ad5f2c
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index d56adc1..87ef403 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -567,6 +567,84 @@
} // namespace
+// Split each SparseLengthsSum*Sparse ("slss") op in `net` into either:
+//  (a) the corresponding non-sparse fused-rowwise SLS op, when the
+//      compressed mapping blob is a single element holding 0 (fallback), or
+//  (b) a SparseLengthsSumSparseLookup op that decompresses
+//      indices/lengths(/weights), followed by the non-sparse SLS op.
+// `ws` is only consulted to inspect the compressed mapping blobs.
+void splitSparseLengthsSumSparse(NetDef* net, const Workspace& ws) {
+  // Maps each sparse SLS op type to its non-sparse fused-rowwise equivalent.
+  const static std::unordered_map<string, string> slss = {
+      {"SparseLengthsSum4BitRowwiseSparse", "SparseLengthsSumFused4BitRowwise"},
+      {"SparseLengthsWeightedSum4BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused4BitRowwise"},
+      {"SparseLengthsSum8BitRowwiseSparse", "SparseLengthsSumFused8BitRowwise"},
+      {"SparseLengthsWeightedSum8BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused8BitRowwise"},
+      {"SparseLengthsSum2BitRowwiseSparse", "SparseLengthsSumFused2BitRowwise"},
+      {"SparseLengthsWeightedSum2BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused2BitRowwise"}};
+  NetDef new_net;
+  new_net.CopyFrom(*net);
+  new_net.mutable_op()->Clear();
+  for (const auto& op : net->op()) {
+    const auto it = slss.find(op.type());
+    if (it == slss.end()) {
+      // Not an slss op: keep as-is.
+      new_net.add_op()->CopyFrom(op);
+    } else {
+      const bool is_weighted =
+          (op.type().find("Weighted") != std::string::npos);
+      // Weighted ops carry an extra "weights" input at index 1, shifting
+      // indices/lengths/compressed_mapping up by one position.
+      const auto& compressed_mapping = op.input(is_weighted ? 4 : 3);
+      const auto* b = ws.GetBlob(compressed_mapping);
+      bool fallback = false;
+      if (b && b->IsType<Tensor>()) {
+        const auto& t = BlobGetTensor(*b, CPU);
+        // A single-element mapping holding 0 marks "no compression".
+        fallback = ((t.numel() == 1) && (t.template data<int32_t>()[0] == 0));
+      }
+
+      if (fallback) {
+        // If fallback, we just replace the original slss op with a normal sls
+        // op, dropping the (trivial) compressed mapping input.
+        OperatorDef new_op;
+        new_op.CopyFrom(op);
+        new_op.set_type(it->second);
+        new_op.mutable_input()->RemoveLast();
+        new_net.add_op()->CopyFrom(new_op);
+      } else {
+        // Otherwise, we replace slss with slss_lookup followed by a normal sls
+        OperatorDef new_op;
+        new_op.CopyFrom(op);
+        new_op.set_type("SparseLengthsSumSparseLookup");
+        new_op.clear_input();
+        const auto& indices_in = is_weighted ? op.input(2) : op.input(1);
+        const auto& lengths_in = is_weighted ? op.input(3) : op.input(2);
+        const auto& compress_mapping = is_weighted ? op.input(4) : op.input(3);
+        const auto& weights_in = is_weighted ? op.input(1) : "";
+        new_op.add_input(indices_in);
+        new_op.add_input(lengths_in);
+        new_op.add_input(compress_mapping);
+        const auto indices_out = indices_in + "_decomp";
+        const auto lengths_out = lengths_in + "_decomp";
+        const auto weights_out = weights_in + "_decomp";
+        new_op.clear_output();
+        new_op.add_output(indices_out);
+        new_op.add_output(lengths_out);
+        if (is_weighted) {
+          new_op.add_input(weights_in);
+          new_op.add_output(weights_out);
+        }
+        new_net.add_op()->CopyFrom(new_op);
+
+        // Second op: the non-sparse SLS consuming the decompressed tensors.
+        new_op.CopyFrom(op);
+        new_op.set_type(it->second);
+        new_op.mutable_input()->RemoveLast();
+        *new_op.mutable_input()->Mutable(is_weighted ? 2 : 1) = indices_out;
+        *new_op.mutable_input()->Mutable(is_weighted ? 3 : 2) = lengths_out;
+        if (is_weighted) {
+          *new_op.mutable_input()->Mutable(1) = weights_out;
+        }
+        new_net.add_op()->CopyFrom(new_op);
+      }
+    }
+  }
+
+  new_net.Swap(net);
+}
+
OnnxifiTransformer::OnnxifiTransformer(const OnnxifiTransformerOptions& opts)
: BackendTransformerBase(), opts_(opts) {
lib_ = onnx::initOnnxifiLibrary();
diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h
index 99fda23..871f0f7 100644
--- a/caffe2/opt/onnxifi_transformer.h
+++ b/caffe2/opt/onnxifi_transformer.h
@@ -16,6 +16,10 @@
class OnnxExporter;
}
+// Split SparseLengthsSumSparse into SparseLengthsSumSparseLookup +
+// SparseLengthsSum
+CAFFE2_API void splitSparseLengthsSumSparse(NetDef* net, const Workspace& ws);
+
struct OnnxifiTransformerOptions final : public BackendTransformOptions {
explicit OnnxifiTransformerOptions() : BackendTransformOptions() {}
diff --git a/caffe2/opt/split_slss_test.cc b/caffe2/opt/split_slss_test.cc
new file mode 100644
index 0000000..d02796c
--- /dev/null
+++ b/caffe2/opt/split_slss_test.cc
@@ -0,0 +1,105 @@
+#include <caffe2/core/common.h>
+#include <caffe2/core/test_utils.h>
+#include <caffe2/core/workspace.h>
+#include <caffe2/opt/onnxifi_transformer.h>
+#include <caffe2/utils/proto_utils.h>
+
+#include <gtest/gtest.h>
+
+using namespace caffe2::testing;
+using namespace caffe2;
+
+namespace {
+// Builds a one-op NetDef running `op_type` over canned input names and
+// creates the "Compressed" mapping blob in `ws`: a single int32 set to 0
+// when `fallback` (triggers the transform's pass-through path), else 1.
+NetDef createTest(
+    const std::string& op_type,
+    Workspace* ws,
+    bool has_weight,
+    bool fallback) {
+  NetDef net;
+  // Weighted ops take an extra "Weight" input at position 1.
+  std::vector<std::string> inputs{
+      "Data", "Weight", "Idx", "Lengths", "Compressed"};
+  if (!has_weight) {
+    inputs = {"Data", "Idx", "Lengths", "Compressed"};
+  }
+  NetMutator(&net).newOp(op_type, inputs, {"Out"});
+  auto* b = ws->CreateBlob("Compressed");
+  auto* t = BlobGetMutableTensor(b, {1}, at::dtype<int32_t>());
+  auto* comp = t->template mutable_data<int32_t>();
+  // A one-element mapping holding 0 means "no compression" -> fallback.
+  *comp = fallback ? 0 : 1;
+  return net;
+}
+
+// Verifies the post-transform net: a single non-sparse SLS op in the
+// fallback case, or a SparseLengthsSumSparseLookup + non-sparse SLS pair
+// wired through the "*_decomp" intermediate tensors otherwise.
+void check(
+    const NetDef& net,
+    const std::string& op_type,
+    bool has_weight,
+    bool fallback) {
+  // Expected sparse -> non-sparse op type mapping (mirrors the transform).
+  const static std::unordered_map<string, string> slss = {
+      {"SparseLengthsSum4BitRowwiseSparse", "SparseLengthsSumFused4BitRowwise"},
+      {"SparseLengthsWeightedSum4BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused4BitRowwise"},
+      {"SparseLengthsSum8BitRowwiseSparse", "SparseLengthsSumFused8BitRowwise"},
+      {"SparseLengthsWeightedSum8BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused8BitRowwise"},
+      {"SparseLengthsSum2BitRowwiseSparse", "SparseLengthsSumFused2BitRowwise"},
+      {"SparseLengthsWeightedSum2BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused2BitRowwise"}};
+  if (fallback) {
+    // Fallback: a single non-sparse op with the mapping input dropped.
+    EXPECT_EQ(net.op_size(), 1);
+    EXPECT_EQ(net.op(0).type(), slss.at(op_type));
+    EXPECT_EQ(net.op(0).input_size(), has_weight ? 4 : 3);
+    EXPECT_EQ(net.op(0).output_size(), 1);
+    EXPECT_EQ(net.op(0).input(0), "Data");
+    EXPECT_EQ(net.op(0).input(has_weight ? 2 : 1), "Idx");
+    EXPECT_EQ(net.op(0).input(has_weight ? 3 : 2), "Lengths");
+    if (has_weight) {
+      EXPECT_EQ(net.op(0).input(1), "Weight");
+    }
+    EXPECT_EQ(net.op(0).output(0), "Out");
+  } else {
+    // Split: lookup op decompressing Idx/Lengths(/Weight), then the SLS.
+    EXPECT_EQ(net.op_size(), 2);
+    EXPECT_EQ(net.op(0).type(), "SparseLengthsSumSparseLookup");
+    EXPECT_EQ(net.op(0).input_size(), has_weight ? 4 : 3);
+    EXPECT_EQ(net.op(0).output_size(), has_weight ? 3 : 2);
+    EXPECT_EQ(net.op(0).input(0), "Idx");
+    EXPECT_EQ(net.op(0).input(1), "Lengths");
+    EXPECT_EQ(net.op(0).input(2), "Compressed");
+    EXPECT_EQ(net.op(0).output(0), "Idx_decomp");
+    EXPECT_EQ(net.op(0).output(1), "Lengths_decomp");
+    if (has_weight) {
+      EXPECT_EQ(net.op(0).input(3), "Weight");
+      EXPECT_EQ(net.op(0).output(2), "Weight_decomp");
+    }
+    EXPECT_EQ(net.op(1).type(), slss.at(op_type));
+    EXPECT_EQ(net.op(1).input_size(), has_weight ? 4 : 3);
+    EXPECT_EQ(net.op(1).output_size(), 1);
+    EXPECT_EQ(net.op(1).input(0), "Data");
+    EXPECT_EQ(net.op(1).input(has_weight ? 2 : 1), "Idx_decomp");
+    EXPECT_EQ(net.op(1).input(has_weight ? 3 : 2), "Lengths_decomp");
+    if (has_weight) {
+      EXPECT_EQ(net.op(1).input(1), "Weight_decomp");
+    }
+    EXPECT_EQ(net.op(1).output(0), "Out");
+  }
+}
+} // namespace
+
+// Sweeps weighted/unweighted x {2,4,8}-bit x fallback/non-fallback and
+// verifies the transformed net shape for every combination.
+TEST(splitSparseLengthsSumSparse, sweep) {
+  std::vector<bool> has_weights = {true, false};
+  std::vector<bool> fallbacks = {true, false};
+  std::vector<int> bits = {2, 4, 8};
+  for (const auto has_weight : has_weights) {
+    for (const auto bit : bits) {
+      // Assemble the op name, e.g. "SparseLengthsSum4BitRowwiseSparse".
+      std::string op_type = "SparseLengths";
+      op_type += (has_weight ? "WeightedSum" : "Sum");
+      op_type += caffe2::to_string(bit);
+      op_type += "BitRowwiseSparse";
+      for (const auto fallback : fallbacks) {
+        Workspace ws;
+        auto net = createTest(op_type, &ws, has_weight, fallback);
+        splitSparseLengthsSumSparse(&net, ws);
+        check(net, op_type, has_weight, fallback);
+      }
+    }
+  }
+}