Add support for ops with multiple outputs (e.g. Split).

PiperOrigin-RevId: 266053965
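
In the quantizer this means every output of such an op must inherit the
quantization parameters of its single quantized input. A minimal sketch of
that propagation, assuming the flatbuffer object API from
tensorflow/lite/schema (SubGraphT, OperatorT, TensorT) and absl::make_unique;
illustrative only, not the patch itself:

    // Copy the input tensor's scale/zero-point to every output of a
    // multi-output op such as Split.
    void PropagateInputParamsToOutputs(SubGraphT* subgraph, OperatorT* op,
                                       int quantized_input_index) {
      const TensorT* input =
          subgraph->tensors[op->inputs[quantized_input_index]].get();
      for (int32_t output_idx : op->outputs) {
        TensorT* output = subgraph->tensors[output_idx].get();
        output->quantization = absl::make_unique<QuantizationParametersT>();
        output->quantization->scale = input->quantization->scale;
        output->quantization->zero_point = input->quantization->zero_point;
        output->type = TensorType_INT8;
      }
    }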
diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index 865f1e9..a7ff3c5 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -191,6 +191,7 @@
         "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
         "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
         "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin",
+        "//tensorflow/lite/tools/optimize:testdata/split.bin",
     ],
     tags = [
         "tflite_not_portable_android",
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 8316dec..f0fc409 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -45,6 +45,13 @@
       property.restrict_same_input_output_scale = true;
       property.version = 2;
       break;
+    case BuiltinOperator_SPLIT:
+      property.arbitrary_outputs = true;
+      // We skip input 0 since it is the split dimension, which is not real-valued.
+      property.inputs = {{1, {}}};
+      property.restrict_same_input_output_scale = true;
+      property.version = 2;
+      break;
     case BuiltinOperator_CONCATENATION:
       property.arbitrary_inputs = true;
       property.outputs = {{0, {}}};
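
For reference, the property registered above is consumed roughly as follows;
the GetOperatorProperty signature is taken from this version of
operator_property.h and may differ elsewhere, so treat this as a hedged usage
sketch:

    operator_property::OperatorProperty property =
        operator_property::GetOperatorProperty(BuiltinOperator_SPLIT);
    // Only input 1 (the data tensor) is quantized; input 0 is the int32
    // split dimension and carries no real-valued data.
    assert(property.inputs.size() == 1 && property.inputs[0].first == 1);
    // The output count is not known statically, so outputs are enumerated
    // per op instance (see GetOutputs in quantize_model.cc below).
    assert(property.arbitrary_outputs);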
diff --git a/tensorflow/lite/tools/optimize/operator_property.h b/tensorflow/lite/tools/optimize/operator_property.h
index a0d1630..b3c9975 100644
--- a/tensorflow/lite/tools/optimize/operator_property.h
+++ b/tensorflow/lite/tools/optimize/operator_property.h
@@ -41,6 +41,8 @@
 
   // Op has arbitrary number of inputs, such as concat.
   bool arbitrary_inputs = false;
+  // Op has arbitrary number of outputs, such as split.
+  bool arbitrary_outputs = false;
   // Input indexes -> input tensor property.
   std::vector<std::pair<int, TensorProperty>> inputs = {};
   // Output indexes -> output tensor property.
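
The two arbitrary_* flags act as escape hatches from the fixed index lists:
when set, the quantizer enumerates the tensors of the concrete op instance
instead of consulting inputs/outputs. A hypothetical entry for another
multi-output op would be filled in the same way as the SPLIT case above:

    // Illustrative only; mirrors the SPLIT entry in operator_property.cc.
    operator_property::OperatorProperty property;
    property.arbitrary_outputs = true;  // outputs enumerated per instance
    property.inputs = {{1, {}}};        // input index -> TensorProperty
    property.restrict_same_input_output_scale = true;
    property.version = 2;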
diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc
index 38cf769..23bf264 100644
--- a/tensorflow/lite/tools/optimize/quantize_model.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model.cc
@@ -297,9 +297,8 @@
         continue;
       }
       // Basically only Concat passes this check.
-      if (!property.restrict_same_input_output_scale ||
-          (property.inputs.size() == 1 && property.outputs.size() == 1 &&
-           property.biases.empty())) {
+      if (!property.arbitrary_inputs ||
+          !property.restrict_same_input_output_scale) {
         continue;
       }
       // If ApplyConstraints and requant is needed, use the min of min and max
@@ -367,10 +366,24 @@
   return inputs;
 }
 
+std::vector<std::pair<int, operator_property::TensorProperty>> GetOutputs(
+    const OperatorT* op, operator_property::OperatorProperty property) {
+  std::vector<std::pair<int, operator_property::TensorProperty>> outputs;
+  if (property.arbitrary_outputs) {
+    for (int i = 0; i < op->outputs.size(); ++i) {
+      outputs.push_back({i, {}});
+    }
+  } else {
+    outputs = property.outputs;
+  }
+  return outputs;
+}
+
 bool ShouldRestrictSameInputOutputScale(
     operator_property::OperatorProperty property) {
-  return (property.inputs.size() == 1 && property.outputs.size() == 1 &&
-          property.biases.empty() && property.restrict_same_input_output_scale);
+  // Ops with multiple inputs (e.g. concat) get restricted in ApplyConstraints.
+  return (!property.arbitrary_inputs &&
+          property.restrict_same_input_output_scale);
 }
 
 bool IsSubgraphInput(SubGraphT* subgraph, int32_t index) {
@@ -541,8 +554,8 @@
     output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
     output_tensor->quantization->scale.push_back(input_scale);
     output_tensor->quantization->zero_point.push_back(input_zero_point);
-    output_tensor->quantization->min.push_back(min);
-    output_tensor->quantization->max.push_back(max);
+    output_tensor->quantization->min = {min};
+    output_tensor->quantization->max = {max};
     output_tensor->type = TensorType_INT8;
   } else if (tensor_property.restriction) {
     const auto scale_and_zp = tensor_property.restricted_value;
@@ -597,7 +610,7 @@
 
       // Quantize operator outputs.
       for (const std::pair<int, operator_property::TensorProperty>& output :
-           property.outputs) {
+           GetOutputs(op, property)) {
         TF_LITE_ENSURE_STATUS(QuantizeOpOutput(
             model, subgraph_idx, op_idx, property, output, error_reporter));
       }
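
Note the division of labor after this change: arbitrary-input ops such as
Concat receive their shared scale in ApplyConstraints, while single-input ops
such as Split are caught by ShouldRestrictSameInputOutputScale when their
outputs are quantized. The new GetOutputs helper mirrors the existing
GetInputs; a self-contained sketch with plain types, compilable outside the
TF Lite tree:

    #include <utility>
    #include <vector>

    struct TensorProperty {};
    struct Property {
      bool arbitrary_outputs = false;
      std::vector<std::pair<int, TensorProperty>> outputs;
    };

    // One {index, default TensorProperty} entry per concrete output when
    // the count is not fixed; otherwise the static list from the property.
    std::vector<std::pair<int, TensorProperty>> GetOutputs(
        const std::vector<int>& op_outputs, const Property& property) {
      std::vector<std::pair<int, TensorProperty>> outputs;
      if (property.arbitrary_outputs) {
        for (int i = 0; i < static_cast<int>(op_outputs.size()); ++i) {
          outputs.push_back({i, {}});
        }
      } else {
        outputs = property.outputs;
      }
      return outputs;
    }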
diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc
index ecf4deb..679b681 100644
--- a/tensorflow/lite/tools/optimize/quantize_model_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -402,6 +402,72 @@
   EXPECT_EQ(model_.operator_codes[1]->version, 2);
 }
 
+class QuantizeSplitModelTest : public QuantizeModelTest {
+ protected:
+  QuantizeSplitModelTest() {
+    input_model_ = ReadModel(internal::kModelSplit);
+    readonly_model_ = input_model_->GetModel();
+    readonly_model_->UnPackTo(&model_);
+  }
+};
+
+// The split has two outputs with different scales; in the resulting model
+// both output scales should be hardcoded to the input scale value.
+TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
+  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
+                              TensorType_INT8, &error_reporter_);
+  EXPECT_EQ(status, kTfLiteOk);
+
+  // There is only one subgraph.
+  const int32_t subgraph_idx = 0;
+  const auto& subgraph = model_.subgraphs[subgraph_idx];
+  const auto& readonly_subgraph =
+      readonly_model_->subgraphs()->Get(subgraph_idx);
+
+  // There should be two ops, split and add, as in the original model.
+  EXPECT_EQ(readonly_subgraph->operators()->size(), 2);
+  EXPECT_EQ(subgraph->operators.size(), 2);
+  const auto& split = subgraph->operators[0];
+  const auto& add = subgraph->operators[1];
+  EXPECT_EQ(model_.operator_codes[split->opcode_index]->builtin_code,
+            BuiltinOperator_SPLIT);
+  EXPECT_EQ(model_.operator_codes[add->opcode_index]->builtin_code,
+            BuiltinOperator_ADD);
+
+  // There should be 5 tensors: input, output, split, split/split_dim, split:1.
+  EXPECT_EQ(subgraph->tensors.size(), 5);
+
+  EXPECT_EQ(subgraph->tensors[0]->type, TensorType_INT8);
+  EXPECT_EQ(subgraph->tensors[0]->name, "input");
+  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
+  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
+  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->scale[0], 1.0);
+  EXPECT_FLOAT_EQ(subgraph->tensors[0]->quantization->zero_point[0], -128);
+  EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT8);
+  EXPECT_EQ(subgraph->tensors[1]->name, "output");
+  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 1);
+  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 1);
+  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->scale[0], 1.0);
+  EXPECT_FLOAT_EQ(subgraph->tensors[1]->quantization->zero_point[0], -128);
+  EXPECT_EQ(subgraph->tensors[2]->type, TensorType_INT8);
+  EXPECT_EQ(subgraph->tensors[2]->name, "split");
+  EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
+  EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
+  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->scale[0], 1.0);
+  EXPECT_FLOAT_EQ(subgraph->tensors[2]->quantization->zero_point[0], -128);
+  EXPECT_EQ(subgraph->tensors[4]->type, TensorType_INT8);
+  EXPECT_EQ(subgraph->tensors[4]->name, "split:1");
+  EXPECT_EQ(subgraph->tensors[4]->quantization->scale.size(), 1);
+  EXPECT_EQ(subgraph->tensors[4]->quantization->zero_point.size(), 1);
+  EXPECT_FLOAT_EQ(subgraph->tensors[4]->quantization->scale[0], 1.0);
+  EXPECT_FLOAT_EQ(subgraph->tensors[4]->quantization->zero_point[0], -128);
+
+  // Check op and versioning.
+  EXPECT_EQ(model_.operator_codes.size(), 2);
+  EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_SPLIT);
+  EXPECT_EQ(model_.operator_codes[0]->version, 2);
+}
+
 class QuantizeConvModel1Test : public QuantizeModelTest {
  protected:
   QuantizeConvModel1Test() {
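
The per-tensor expectations in the new test are intentionally exhaustive; a
hypothetical gtest helper (not part of the patch) could compress the repeated
scale/zero-point checks:

    // ExpectSameQuantParams is illustrative, not in the patch.
    void ExpectSameQuantParams(const TensorT& a, const TensorT& b) {
      ASSERT_EQ(a.quantization->scale.size(), 1);
      ASSERT_EQ(b.quantization->scale.size(), 1);
      EXPECT_FLOAT_EQ(a.quantization->scale[0], b.quantization->scale[0]);
      EXPECT_EQ(a.quantization->zero_point[0], b.quantization->zero_point[0]);
    }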
diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc
index 5f38d9a..3cfd5f5b 100644
--- a/tensorflow/lite/tools/optimize/test_util.cc
+++ b/tensorflow/lite/tools/optimize/test_util.cc
@@ -47,6 +47,8 @@
 
 const char* kModelMixed = "mixed.bin";
 
+const char* kModelSplit = "split.bin";
+
 int FailOnErrorReporter::Report(const char* format, va_list args) {
   char buf[1024];
   vsnprintf(buf, sizeof(buf), format, args);
diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h
index 1e7e14c..bf42a30b 100644
--- a/tensorflow/lite/tools/optimize/test_util.h
+++ b/tensorflow/lite/tools/optimize/test_util.h
@@ -73,6 +73,9 @@
 // reshape->custom->custom->squeeze.
 extern const char* kModelMixed;
 
+// Test model with split op.
+extern const char* kModelSplit;
+
 // An error reporter that fails on testing.
 class FailOnErrorReporter : public ErrorReporter {
  public:
diff --git a/tensorflow/lite/tools/optimize/testdata/split.bin b/tensorflow/lite/tools/optimize/testdata/split.bin
new file mode 100644
index 0000000..3341df9
--- /dev/null
+++ b/tensorflow/lite/tools/optimize/testdata/split.bin
Binary files differ
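
To inspect the new testdata model locally, the fixture's
ReadModel(internal::kModelSplit) amounts to roughly the following (API from
tensorflow/lite/model.h; hedged sketch):

    auto fb_model = tflite::FlatBufferModel::BuildFromFile(
        "tensorflow/lite/tools/optimize/testdata/split.bin");
    const tflite::Model* readonly_model = fb_model->GetModel();
    tflite::ModelT model;
    readonly_model->UnPackTo(&model);  // mutable copy for quantization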