Support shape inference and lowering of SparseLengthsWeightedSumFused4BitRowwise (#32257)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32257
Pull Request resolved: https://github.com/pytorch/glow/pull/4018
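As the title says. The bound shape inferencer now handles both
SparseLengthsSumFused4BitRowwise and SparseLengthsWeightedSumFused4BitRowwise.
For these ops the inferred output width is (input columns - 4) * 2: each row
carries a 2-byte fp16 scale and a 2-byte fp16 bias, and every remaining uint8
element packs two 4-bit entries.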
Test Plan:
Unit tests:
```
buck test glow:masterCaffe2ImporterTest -- caffe2.SparseLengthsSumFused4BitRowwise
buck test caffe2/caffe2/opt:bound_shape_inference_test
```
Reviewed By: jfix71
Differential Revision: D19389014
fbshipit-source-id: 5f6863443adee5d3bf7a50a105866441eefb9560
diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc
index c31410e..4ea1a28 100644
--- a/caffe2/opt/bound_shape_inference_test.cc
+++ b/caffe2/opt/bound_shape_inference_test.cc
@@ -126,6 +126,51 @@
{spec.max_batch_size, 50});
}
+TEST(BoundShapeInference, SparseLengthsSumFused4BitRowwise) {
+ NetDef net;
+ net.add_op()->CopyFrom(CreateOperatorDef(
+ "SparseLengthsSumFused4BitRowwise",
+ "",
+ {"Weights", "Data", "Lengths"},
+ {"Out"},
+ {}));
+ ShapeInfoMap shape_map;
+ shape_map.emplace(
+ "Weights",
+ makeTensorInfo(
+ {TensorBoundShape_DimType_CONSTANT,
+ TensorBoundShape_DimType_CONSTANT},
+ {1000, 54},
+ TensorProto_DataType_INT8));
+ BoundShapeSpec spec(20, 1000);
+ BoundShapeInferencer eng(spec);
+ eng.InferBoundShapeAndType(net, shape_map, nullptr);
+ const auto& out_shape = eng.shape_info();
+ verifyShapeInfo(
+ out_shape,
+ "Weights",
+ {TensorBoundShape_DimType_CONSTANT, TensorBoundShape_DimType_CONSTANT},
+ {1000, 54},
+ TensorProto_DataType_INT8);
+ verifyShapeInfo(
+ out_shape,
+ "Data",
+ {TensorBoundShape_DimType_FEATURE_MAX_DEFAULT},
+ {spec.max_seq_size},
+ TensorProto_DataType_INT64);
+ verifyShapeInfo(
+ out_shape,
+ "Lengths",
+ {TensorBoundShape_DimType_BATCH},
+ {spec.max_batch_size},
+ TensorProto_DataType_INT32);
+ verifyShapeInfo(
+ out_shape,
+ "Out",
+ {TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT},
+ {spec.max_batch_size, 100});
+}
+
TEST(BoundShapeInference, LengthsRangeFill) {
NetDef net;
net.add_op()->CopyFrom(
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index 5e9148c..d1fee29 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -65,7 +65,9 @@
if (op.type() == "SparseLengthsSum" ||
op.type() == "SparseLengthsSumFused8BitRowwise" ||
op.type() == "SparseLengthsWeightedSum" ||
- op.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
+ op.type() == "SparseLengthsWeightedSumFused8BitRowwise" ||
+ op.type() == "SparseLengthsSumFused4BitRowwise" ||
+ op.type() == "SparseLengthsWeightedSumFused4BitRowwise") {
InferSparseLengthsSum(op);
} else if (
op.type() == "FC" || op.type() == "FCTransposed" ||
@@ -258,10 +260,14 @@
"needs to be 2D");
int weight = (op.type() == "SparseLengthsWeightedSum" ||
- op.type() == "SparseLengthsWeightedSumFused8BitRowwise")
+ op.type() == "SparseLengthsWeightedSumFused8BitRowwise" ||
+ op.type() == "SparseLengthsWeightedSumFused4BitRowwise")
? 1
: 0;
+ const bool is4bit = op.type() == "SparseLengthsSumFused4BitRowwise" ||
+ op.type() == "SparseLengthsWeightedSumFused4BitRowwise";
+
if (weight) {
CAFFE_ENFORCE_EQ(
op.input_size(), 4, "SparseLengthsWeightedSum must have 4 inputs");
@@ -292,12 +298,20 @@
current_dim_type_ = TensorBoundShape_DimType_BATCH;
current_max_batch_size_ = spec_.max_batch_size;
auto output_dim1 = it->second.shape.dims(1);
- // If the op is SparseLengthsSumFused8BitRowwise, we need to extract 4 for
- // scale and 4 byte for bias (https://fburl.com/t6dp9tsc)
+ // If the op is SparseLengthsSumFused8BitRowwise, we need to extract 4 bytes
+ // for fp32 scale and 4 bytes for fp32 bias (https://fburl.com/t6dp9tsc)
if (op.type() == "SparseLengthsSumFused8BitRowwise" ||
op.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
output_dim1 -= 8;
}
+ // If the op is SparseLengthsSumFused4BitRowwise, we need to extract 2 bytes
+ // for fp16 scale and 2 bytes for fp16 bias. Then we double it because we pack
+ // 2 entries into 1 uint8 element of the embedding table.
+ // (https://fburl.com/diffusion/stmsyz74)
+ else if (is4bit) {
+ output_dim1 -= 4;
+ output_dim1 *= 2;
+ }
CAFFE_ENFORCE_GE(
it->second.getDimType().size(), 2, "input(0): ", op.input(0));
CheckAndSetTensorBoundShape(