Support shape inference and lowering of SparseLengthsWeightedSumFused4BitRowwise (#32257)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32257

Pull Request resolved: https://github.com/pytorch/glow/pull/4018

As titled: adds bound-shape-inference support for the 4-bit fused rowwise ops
(SparseLengthsSumFused4BitRowwise and SparseLengthsWeightedSumFused4BitRowwise)
in BoundShapeInferencer, with a unit test covering the output-dim computation.

Test Plan:
Unit tests:
```
buck test glow:masterCaffe2ImporterTest -- caffe2.SparseLengthsSumFused4BitRowwise
buck test caffe2/caffe2/opt:bound_shape_inference_test
```

Reviewed By: jfix71

Differential Revision: D19389014

fbshipit-source-id: 5f6863443adee5d3bf7a50a105866441eefb9560
diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc
index c31410e..4ea1a28 100644
--- a/caffe2/opt/bound_shape_inference_test.cc
+++ b/caffe2/opt/bound_shape_inference_test.cc
@@ -126,6 +126,51 @@
       {spec.max_batch_size, 50});
 }
 
+TEST(BoundShapeInference, SparseLengthsSumFused4BitRowwise) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "SparseLengthsSumFused4BitRowwise",
+      "",
+      {"Weights", "Data", "Lengths"},
+      {"Out"},
+      {}));
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "Weights",
+      makeTensorInfo(
+          {TensorBoundShape_DimType_CONSTANT,
+           TensorBoundShape_DimType_CONSTANT},
+          {1000, 54},
+          TensorProto_DataType_INT8));
+  BoundShapeSpec spec(20, 1000);
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map, nullptr);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape,
+      "Weights",
+      {TensorBoundShape_DimType_CONSTANT, TensorBoundShape_DimType_CONSTANT},
+      {1000, 54},
+      TensorProto_DataType_INT8);
+  verifyShapeInfo(
+      out_shape,
+      "Data",
+      {TensorBoundShape_DimType_FEATURE_MAX_DEFAULT},
+      {spec.max_seq_size},
+      TensorProto_DataType_INT64);
+  verifyShapeInfo(
+      out_shape,
+      "Lengths",
+      {TensorBoundShape_DimType_BATCH},
+      {spec.max_batch_size},
+      TensorProto_DataType_INT32);
+  verifyShapeInfo(
+      out_shape,
+      "Out",
+      {TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT},
+      {spec.max_batch_size, 100});
+}
+
 TEST(BoundShapeInference, LengthsRangeFill) {
   NetDef net;
   net.add_op()->CopyFrom(
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index 5e9148c..d1fee29 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -65,7 +65,9 @@
   if (op.type() == "SparseLengthsSum" ||
       op.type() == "SparseLengthsSumFused8BitRowwise" ||
       op.type() == "SparseLengthsWeightedSum" ||
-      op.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
+      op.type() == "SparseLengthsWeightedSumFused8BitRowwise" ||
+      op.type() == "SparseLengthsSumFused4BitRowwise" ||
+      op.type() == "SparseLengthsWeightedSumFused4BitRowwise") {
     InferSparseLengthsSum(op);
   } else if (
       op.type() == "FC" || op.type() == "FCTransposed" ||
@@ -258,10 +260,14 @@
       "needs to be 2D");
 
   int weight = (op.type() == "SparseLengthsWeightedSum" ||
-                op.type() == "SparseLengthsWeightedSumFused8BitRowwise")
+                op.type() == "SparseLengthsWeightedSumFused8BitRowwise" ||
+                op.type() == "SparseLengthsWeightedSumFused4BitRowwise")
       ? 1
       : 0;
 
+  const bool is4bit = op.type() == "SparseLengthsSumFused4BitRowwise" ||
+      op.type() == "SparseLengthsWeightedSumFused4BitRowwise";
+
   if (weight) {
     CAFFE_ENFORCE_EQ(
         op.input_size(), 4, "SparseLengthsWeightedSum must have 4 inputs");
@@ -292,12 +298,20 @@
   current_dim_type_ = TensorBoundShape_DimType_BATCH;
   current_max_batch_size_ = spec_.max_batch_size;
   auto output_dim1 = it->second.shape.dims(1);
-  // If the op is SparseLengthsSumFused8BitRowwise, we need to extract 4 for
-  // scale and 4 byte for bias (https://fburl.com/t6dp9tsc)
+  // If the op is SparseLengthsSumFused8BitRowwise, subtract the 4 bytes of
+  // fp32 scale and 4 bytes of fp32 bias stored per row (https://fburl.com/t6dp9tsc)
   if (op.type() == "SparseLengthsSumFused8BitRowwise" ||
       op.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
     output_dim1 -= 8;
   }
+  // If the op is SparseLengthsSumFused4BitRowwise, subtract the 2 bytes of
+  // fp16 scale and 2 bytes of fp16 bias stored per row, then double the result
+  // because each uint8 element of the embedding table packs two 4-bit entries.
+  // (https://fburl.com/diffusion/stmsyz74)
+  else if (is4bit) {
+    output_dim1 -= 4;
+    output_dim1 *= 2;
+  }
   CAFFE_ENFORCE_GE(
       it->second.getDimType().size(), 2, "input(0): ", op.input(0));
   CheckAndSetTensorBoundShape(