Add support for ops with multiple outputs (i.e. Split).
PiperOrigin-RevId: 266053965
diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index 865f1e9..a7ff3c5 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -191,6 +191,7 @@
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
"//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin",
+ "//tensorflow/lite/tools/optimize:testdata/split.bin",
],
tags = [
"tflite_not_portable_android",
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 8316dec..f0fc409 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -45,6 +45,13 @@
property.restrict_same_input_output_scale = true;
property.version = 2;
break;
+ case BuiltinOperator_SPLIT:
+ property.arbitrary_outputs = true;
+ // We skip input 0 since it is the split dimension, which is not real-valued.
+ property.inputs = {{1, {}}};
+ property.restrict_same_input_output_scale = true;
+ property.version = 2;
+ break;
case BuiltinOperator_CONCATENATION:
property.arbitrary_inputs = true;
property.outputs = {{0, {}}};
diff --git a/tensorflow/lite/tools/optimize/operator_property.h b/tensorflow/lite/tools/optimize/operator_property.h
index a0d1630..b3c9975 100644
--- a/tensorflow/lite/tools/optimize/operator_property.h
+++ b/tensorflow/lite/tools/optimize/operator_property.h
@@ -41,6 +41,8 @@
// Op has arbitrary number of inputs, such as concat.
bool arbitrary_inputs = false;
+ // Op has arbitrary number of outputs, such as split.
+ bool arbitrary_outputs = false;
// Input indexes -> input tensor property.
std::vector<std::pair<int, TensorProperty>> inputs = {};
// Output indexes -> output tensor property.
diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc
index 38cf769..23bf264 100644
--- a/tensorflow/lite/tools/optimize/quantize_model.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model.cc
@@ -297,9 +297,8 @@
continue;
}
// Basically only Concat passes this check.
- if (!property.restrict_same_input_output_scale ||
- (property.inputs.size() == 1 && property.outputs.size() == 1 &&
- property.biases.empty())) {
+ if (!property.arbitrary_inputs ||
+ !property.restrict_same_input_output_scale) {
continue;
}
// If ApplyConstraints and requant is needed, use the min of min and max
@@ -367,10 +366,24 @@
return inputs;
}
+std::vector<std::pair<int, operator_property::TensorProperty>> GetOutputs(
+ const OperatorT* op, operator_property::OperatorProperty property) {
+ std::vector<std::pair<int, operator_property::TensorProperty>> outputs;
+ if (property.arbitrary_outputs) {
+ for (int i = 0; i < op->outputs.size(); ++i) {
+ outputs.push_back({i, {}});
+ }
+ } else {
+ outputs = property.outputs;
+ }
+ return outputs;
+}
+
bool ShouldRestrictSameInputOutputScale(
operator_property::OperatorProperty property) {
- return (property.inputs.size() == 1 && property.outputs.size() == 1 &&
- property.biases.empty() && property.restrict_same_input_output_scale);
+ // Ops with multiple inputs (e.g. concat) get restricted in ApplyConstraints.
+ return (!property.arbitrary_inputs &&
+ property.restrict_same_input_output_scale);
}
bool IsSubgraphInput(SubGraphT* subgraph, int32_t index) {
@@ -541,8 +554,8 @@
output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
output_tensor->quantization->scale.push_back(input_scale);
output_tensor->quantization->zero_point.push_back(input_zero_point);
- output_tensor->quantization->min.push_back(min);
- output_tensor->quantization->max.push_back(max);
+ output_tensor->quantization->min = {min};
+ output_tensor->quantization->max = {max};
output_tensor->type = TensorType_INT8;
} else if (tensor_property.restriction) {
const auto scale_and_zp = tensor_property.restricted_value;
@@ -597,7 +610,7 @@
// Quantize operator outputs.
for (const std::pair<int, operator_property::TensorProperty>& output :
- property.outputs) {
+ GetOutputs(op, property)) {
TF_LITE_ENSURE_STATUS(QuantizeOpOutput(
model, subgraph_idx, op_idx, property, output, error_reporter));
}
diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc
index ecf4deb..679b681 100644
--- a/tensorflow/lite/tools/optimize/quantize_model_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -402,6 +402,72 @@
EXPECT_EQ(model_.operator_codes[1]->version, 2);
}
+class QuantizeSplitModelTest : public QuantizeModelTest {
+ protected:
+ QuantizeSplitModelTest() {
+ input_model_ = ReadModel(internal::kModelSplit);
+ readonly_model_ = input_model_->GetModel();
+ readonly_model_->UnPackTo(&model_);
+ }
+};
+
+// There are two outputs for split with different scales; in the resulting model
+// the output scales should be hardcoded to the input scale value.
+TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
+ auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
+ TensorType_INT8, &error_reporter_);
+ EXPECT_EQ(status, kTfLiteOk);
+
+ // There is only one subgraph.
+ const int32_t subgraph_idx = 0;
+ const auto& subgraph = model_.subgraphs[subgraph_idx];
+ const auto& readonly_subgraph =
+ readonly_model_->subgraphs()->Get(subgraph_idx);
+
+ // There should be two ops: the split and add in the original model.
+ EXPECT_EQ(readonly_subgraph->operators()->size(), 2);
+ EXPECT_EQ(subgraph->operators.size(), 2);
+ const auto& split = subgraph->operators[0];
+ const auto& add = subgraph->operators[1];
+ EXPECT_EQ(model_.operator_codes[split->opcode_index]->builtin_code,
+ BuiltinOperator_SPLIT);
+ EXPECT_EQ(model_.operator_codes[add->opcode_index]->builtin_code,
+ BuiltinOperator_ADD);
+
+ // There should be 5 tensors: input, output, split, split/split_dim, split:1.
+ EXPECT_EQ(subgraph->tensors.size(), 5);
+
+ EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8);
+ EXPECT_EQ(subgraph->tensors[0]->name, "input");
+ EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
+ EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
+ EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 1.0);
+ EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128);
+ EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8);
+ EXPECT_EQ(subgraph->tensors[1]->name, "output");
+ EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
+ EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
+ EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 1.0);
+ EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128);
+ EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8);
+ EXPECT_EQ(subgraph->tensors[2]->name, "split");
+ EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
+ EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
+ EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 1.0);
+ EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128);
+ EXPECT_EQ(subgraph->tensors[4]->type, TensorType_INT8);
+ EXPECT_EQ(subgraph->tensors[4]->name, "split:1");
+ EXPECT_EQ(subgraph->tensors[4]->quantization->scale.size(), 1);
+ EXPECT_EQ(subgraph->tensors[4]->quantization->zero_point.size(), 1);
+ EXPECT_FLOAT_EQ(subgraph->tensors[4]->quantization->scale[0], 1.0);
+ EXPECT_FLOAT_EQ(subgraph->tensors[4]->quantization->zero_point[0], -128);
+
+ // check op and versioning.
+ EXPECT_EQ(model_.operator_codes.size(), 2);
+ EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_SPLIT);
+ EXPECT_EQ(model_.operator_codes[0]->version, 2);
+}
+
class QuantizeConvModel1Test : public QuantizeModelTest {
protected:
QuantizeConvModel1Test() {
diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc
index 5f38d9a..3cfd5f5b 100644
--- a/tensorflow/lite/tools/optimize/test_util.cc
+++ b/tensorflow/lite/tools/optimize/test_util.cc
@@ -47,6 +47,8 @@
const char* kModelMixed = "mixed.bin";
+const char* kModelSplit = "split.bin";
+
int FailOnErrorReporter::Report(const char* format, va_list args) {
char buf[1024];
vsnprintf(buf, sizeof(buf), format, args);
diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h
index 1e7e14c..bf42a30b 100644
--- a/tensorflow/lite/tools/optimize/test_util.h
+++ b/tensorflow/lite/tools/optimize/test_util.h
@@ -73,6 +73,9 @@
// reshape->custom->custom->squeeze.
extern const char* kModelMixed;
+// Test model with split op.
+extern const char* kModelSplit;
+
// An error reporter that fails on testing.
class FailOnErrorReporter : public ErrorReporter {
public:
diff --git a/tensorflow/lite/tools/optimize/testdata/split.bin b/tensorflow/lite/tools/optimize/testdata/split.bin
new file mode 100644
index 0000000..3341df9
--- /dev/null
+++ b/tensorflow/lite/tools/optimize/testdata/split.bin
Binary files differ