Transform pass to split SparseLengthsSumSparse (#35522)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35522

We need to apply this transform pass to the net before lowering it to Glow. The pass splits each SparseLengthsSum*Sparse op into a SparseLengthsSumSparseLookup op followed by the corresponding fused SparseLengthsSum* op, and falls back to the fused op alone when the compression mapping is trivial.
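For context, a minimal sketch of a call site, assuming a hypothetical preprocessForGlow wrapper (only splitSparseLengthsSumSparse itself is added by this diff):

```
#include <caffe2/opt/onnxifi_transformer.h>

// Hypothetical wrapper, shown for illustration only.
void preprocessForGlow(caffe2::NetDef* net, const caffe2::Workspace& ws) {
  // Split every SparseLengthsSum*Sparse op into SparseLengthsSumSparseLookup +
  // the corresponding fused SparseLengthsSum op (or fall back), in place.
  caffe2::splitSparseLengthsSumSparse(net, ws);
  // ... then proceed with the usual OnnxifiTransformer lowering of `net`.
}
```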

Test Plan:
```
buck test caffe2/caffe2/opt/custom:split_slss_test
```

Reviewed By: ipiszy

Differential Revision: D20688451

fbshipit-source-id: 22c0f5d0dcf97cc51cdc86bfc0abd90328ad5f2c
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index d56adc1..87ef403 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -567,6 +567,84 @@
 
 } // namespace
 
+void splitSparseLengthsSumSparse(NetDef* net, const Workspace& ws) {
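+  // Maps each *Sparse op type to the fused op that consumes the decompressed
+  // indices/lengths (or the original ones in the fallback case).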
+  const static std::unordered_map<string, string> slss = {
+      {"SparseLengthsSum4BitRowwiseSparse", "SparseLengthsSumFused4BitRowwise"},
+      {"SparseLengthsWeightedSum4BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused4BitRowwise"},
+      {"SparseLengthsSum8BitRowwiseSparse", "SparseLengthsSum8FusedBitRowwise"},
+      {"SparseLengthsWeightedSum8BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused8BitRowwise"},
+      {"SparseLengthsSum2BitRowwiseSparse", "SparseLengthsSumFused2BitRowwise"},
+      {"SparseLengthsWeightedSum2BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused2BitRowwise"}};
+  NetDef new_net;
+  new_net.CopyFrom(*net);
+  new_net.mutable_op()->Clear();
+  for (const auto& op : net->op()) {
+    const auto it = slss.find(op.type());
+    if (it == slss.end()) {
+      new_net.add_op()->CopyFrom(op);
+    } else {
+      const bool is_weighted =
+          (op.type().find("Weighted") != std::string::npos);
+      const auto& compressed_mapping = op.input(is_weighted ? 4 : 3);
+      const auto* b = ws.GetBlob(compressed_mapping);
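+      // A single-element compression mapping holding 0 means no compression
+      // was applied, so the op can fall back to the plain fused version.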
+      bool fallback = false;
+      if (b && b->IsType<Tensor>()) {
+        const auto& t = BlobGetTensor(*b, CPU);
+        fallback = ((t.numel() == 1) && (t.template data<int32_t>()[0] == 0));
+      }
+
+      if (fallback) {
+        // In the fallback case, replace the original SLSS op with the
+        // corresponding fused SLS op and drop the compression mapping input.
+        OperatorDef new_op;
+        new_op.CopyFrom(op);
+        new_op.set_type(it->second);
+        new_op.mutable_input()->RemoveLast();
+        new_net.add_op()->CopyFrom(new_op);
+      } else {
+        // Otherwise, we replace slss with slss_lookup followed by a normal sls
+        OperatorDef new_op;
+        new_op.CopyFrom(op);
+        new_op.set_type("SparseLengthsSumSparseLookup");
+        new_op.clear_input();
+        const auto& indices_in = is_weighted ? op.input(2) : op.input(1);
+        const auto& lengths_in = is_weighted ? op.input(3) : op.input(2);
+        const auto& compress_mapping = is_weighted ? op.input(4) : op.input(3);
+        const auto& weights_in = is_weighted ? op.input(1) : "";
+        new_op.add_input(indices_in);
+        new_op.add_input(lengths_in);
+        new_op.add_input(compress_mapping);
+        const auto indices_out = indices_in + "_decomp";
+        const auto lengths_out = lengths_in + "_decomp";
+        const auto weights_out = weights_in + "_decomp";
+        new_op.clear_output();
+        new_op.add_output(indices_out);
+        new_op.add_output(lengths_out);
+        if (is_weighted) {
+          new_op.add_input(weights_in);
+          new_op.add_output(weights_out);
+        }
+        new_net.add_op()->CopyFrom(new_op);
+
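+        // The fused op then consumes the decompressed indices/lengths (and
+        // weights) produced by the lookup op above.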
+        new_op.CopyFrom(op);
+        new_op.set_type(it->second);
+        new_op.mutable_input()->RemoveLast();
+        *new_op.mutable_input()->Mutable(is_weighted ? 2 : 1) = indices_out;
+        *new_op.mutable_input()->Mutable(is_weighted ? 3 : 2) = lengths_out;
+        if (is_weighted) {
+          *new_op.mutable_input()->Mutable(1) = weights_out;
+        }
+        new_net.add_op()->CopyFrom(new_op);
+      }
+    }
+  }
+
+  new_net.Swap(net);
+}
+
 OnnxifiTransformer::OnnxifiTransformer(const OnnxifiTransformerOptions& opts)
     : BackendTransformerBase(), opts_(opts) {
   lib_ = onnx::initOnnxifiLibrary();
diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h
index 99fda23..871f0f7 100644
--- a/caffe2/opt/onnxifi_transformer.h
+++ b/caffe2/opt/onnxifi_transformer.h
@@ -16,6 +16,10 @@
 class OnnxExporter;
 }
 
+// Split each SparseLengthsSum*Sparse op into SparseLengthsSumSparseLookup +
+// the corresponding fused SparseLengthsSum op (or fall back to the fused op).
+CAFFE2_API void splitSparseLengthsSumSparse(NetDef* net, const Workspace& ws);
+
 struct OnnxifiTransformerOptions final : public BackendTransformOptions {
   explicit OnnxifiTransformerOptions() : BackendTransformOptions() {}
 
diff --git a/caffe2/opt/split_slss_test.cc b/caffe2/opt/split_slss_test.cc
new file mode 100644
index 0000000..d02796c
--- /dev/null
+++ b/caffe2/opt/split_slss_test.cc
@@ -0,0 +1,105 @@
+#include <caffe2/core/common.h>
+#include <caffe2/core/test_utils.h>
+#include <caffe2/core/workspace.h>
+#include <caffe2/opt/onnxifi_transformer.h>
+#include <caffe2/utils/proto_utils.h>
+
+#include <gtest/gtest.h>
+
+using namespace caffe2::testing;
+using namespace caffe2;
+
+namespace {
+NetDef createTest(
+    const std::string& op_type,
+    Workspace* ws,
+    bool has_weight,
+    bool fallback) {
+  NetDef net;
+  std::vector<std::string> inputs{
+      "Data", "Weight", "Idx", "Lengths", "Compressed"};
+  if (!has_weight) {
+    inputs = {"Data", "Idx", "Lengths", "Compressed"};
+  }
+  NetMutator(&net).newOp(op_type, inputs, {"Out"});
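+  // Single-element "Compressed" tensor: 0 exercises the fallback path,
+  // non-zero exercises the split (lookup + fused) path.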
+  auto* b = ws->CreateBlob("Compressed");
+  auto* t = BlobGetMutableTensor(b, {1}, at::dtype<int32_t>());
+  auto* comp = t->template mutable_data<int32_t>();
+  *comp = fallback ? 0 : 1;
+  return net;
+}
+
+void check(
+    const NetDef& net,
+    const std::string& op_type,
+    bool has_weight,
+    bool fallback) {
+  const static std::unordered_map<string, string> slss = {
+      {"SparseLengthsSum4BitRowwiseSparse", "SparseLengthsSumFused4BitRowwise"},
+      {"SparseLengthsWeightedSum4BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused4BitRowwise"},
+      {"SparseLengthsSum8BitRowwiseSparse", "SparseLengthsSum8FusedBitRowwise"},
+      {"SparseLengthsWeightedSum8BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused8BitRowwise"},
+      {"SparseLengthsSum2BitRowwiseSparse", "SparseLengthsSumFused2BitRowwise"},
+      {"SparseLengthsWeightedSum2BitRowwiseSparse",
+       "SparseLengthsWeightedSumFused2BitRowwise"}};
+  if (fallback) {
+    EXPECT_EQ(net.op_size(), 1);
+    EXPECT_EQ(net.op(0).type(), slss.at(op_type));
+    EXPECT_EQ(net.op(0).input_size(), has_weight ? 4 : 3);
+    EXPECT_EQ(net.op(0).output_size(), 1);
+    EXPECT_EQ(net.op(0).input(0), "Data");
+    EXPECT_EQ(net.op(0).input(has_weight ? 2 : 1), "Idx");
+    EXPECT_EQ(net.op(0).input(has_weight ? 3 : 2), "Lengths");
+    if (has_weight) {
+      EXPECT_EQ(net.op(0).input(1), "Weight");
+    }
+    EXPECT_EQ(net.op(0).output(0), "Out");
+  } else {
+    EXPECT_EQ(net.op_size(), 2);
+    EXPECT_EQ(net.op(0).type(), "SparseLengthsSumSparseLookup");
+    EXPECT_EQ(net.op(0).input_size(), has_weight ? 4 : 3);
+    EXPECT_EQ(net.op(0).output_size(), has_weight ? 3 : 2);
+    EXPECT_EQ(net.op(0).input(0), "Idx");
+    EXPECT_EQ(net.op(0).input(1), "Lengths");
+    EXPECT_EQ(net.op(0).input(2), "Compressed");
+    EXPECT_EQ(net.op(0).output(0), "Idx_decomp");
+    EXPECT_EQ(net.op(0).output(1), "Lengths_decomp");
+    if (has_weight) {
+      EXPECT_EQ(net.op(0).input(3), "Weight");
+      EXPECT_EQ(net.op(0).output(2), "Weight_decomp");
+    }
+    EXPECT_EQ(net.op(1).type(), slss.at(op_type));
+    EXPECT_EQ(net.op(1).input_size(), has_weight ? 4 : 3);
+    EXPECT_EQ(net.op(1).output_size(), 1);
+    EXPECT_EQ(net.op(1).input(0), "Data");
+    EXPECT_EQ(net.op(1).input(has_weight ? 2 : 1), "Idx_decomp");
+    EXPECT_EQ(net.op(1).input(has_weight ? 3 : 2), "Lengths_decomp");
+    if (has_weight) {
+      EXPECT_EQ(net.op(1).input(1), "Weight_decomp");
+    }
+    EXPECT_EQ(net.op(1).output(0), "Out");
+  }
+}
+} // namespace
+
+TEST(splitSparseLengthsSumSparse, sweep) {
+  std::vector<bool> has_weights = {true, false};
+  std::vector<bool> fallbacks = {true, false};
+  std::vector<int> bits = {2, 4, 8};
+  for (const auto has_weight : has_weights) {
+    for (const auto bit : bits) {
+      std::string op_type = "SparseLengths";
+      op_type += (has_weight ? "WeightedSum" : "Sum");
+      op_type += caffe2::to_string(bit);
+      op_type += "BitRowwiseSparse";
+      for (const auto fallback : fallbacks) {
+        Workspace ws;
+        auto net = createTest(op_type, &ws, has_weight, fallback);
+        splitSparseLengthsSumSparse(&net, ws);
+        check(net, op_type, has_weight, fallback);
+      }
+    }
+  }
+}