Add option to revert to previous hybrid quantization scheme

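The updated hybrid scheme (per-channel weight quantization for CONV_2D and
DEPTHWISE_CONV_2D plus asymmetric input quantization) stays the default. The
previous behavior can be requested per call through a new QuantizeWeights
overload, or at build time by defining TFLITE_USE_PREVIOUS_HYBRID_SCHEME,
which flips kUseUpdatedHybridSchemeDefault to false.

Minimal usage sketch of the new overload (assumes input_model is an
already-loaded const tflite::Model* and uses 1024 as an example
weights_min_num_elements threshold):

  flatbuffers::FlatBufferBuilder builder;
  tflite::optimize::CustomOpMap custom_op_map;
  TfLiteStatus status = tflite::optimize::QuantizeWeights(
      &builder, input_model, /*weights_min_num_elements=*/1024, custom_op_map,
      /*use_updated_hybrid_scheme=*/false);
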
PiperOrigin-RevId: 322886366
Change-Id: I7665938cfecd27e1be09ccfdd145ec02eb03d675
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index 8bef019..e4840ae 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -130,7 +130,8 @@
 
 // Returns true if the operator supports hybrid evaluation.
 bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
-                          const CustomOpMap& custom_op_map) {
+                          const CustomOpMap& custom_op_map,
+                          bool use_updated_hybrid_scheme) {
   const BuiltinOperator builtin_op_code = op_code->builtin_code;
   // Operations that support hybrid evaluation.
   bool eval_hybrid = false;
@@ -144,7 +145,6 @@
     }
   } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
              builtin_op_code == BuiltinOperator_CONV_2D ||
-             builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
              builtin_op_code == BuiltinOperator_SVDF ||
              builtin_op_code == BuiltinOperator_RNN ||
              builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
@@ -158,6 +158,8 @@
     if (options->kernel_type == LSTMKernelType_FULL) {
       eval_hybrid = true;
     }
+  } else if (builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
+    eval_hybrid = use_updated_hybrid_scheme;
   }
   return eval_hybrid;
 }
@@ -191,7 +193,7 @@
     const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements,
     const CustomOpMap& custom_op_map,
     absl::flat_hash_map<int32_t, TensorPerChannel>* tensor_map,
-    int subgraph_index) {
+    int subgraph_index, bool use_updated_hybrid_scheme) {
   SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get();
   const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
 
@@ -231,43 +233,46 @@
     }
 
     if (op_code->builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
-      tensor_map->insert(
-          {tensor_idx, {tensor, /*is_per_channel=*/true, /*dim=*/3}});
+      tensor_map->insert({tensor_idx,
+                          {tensor, /*is_per_channel=*/use_updated_hybrid_scheme,
+                           /*dim=*/3}});
     } else if (op_code->builtin_code == BuiltinOperator_CONV_2D) {
-      tensor_map->insert(
-          {tensor_idx, {tensor, /*is_per_channel=*/true, /*dim=*/0}});
+      tensor_map->insert({tensor_idx,
+                          {tensor, /*is_per_channel=*/use_updated_hybrid_scheme,
+                           /*dim=*/0}});
     } else {
       switch (op_code->builtin_code) {
         case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
           op->builtin_options.AsBidirectionalSequenceLSTMOptions()
-              ->asymmetric_quantize_inputs = true;
+              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
           break;
         case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
           op->builtin_options.AsBidirectionalSequenceRNNOptions()
-              ->asymmetric_quantize_inputs = true;
+              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
           break;
         case BuiltinOperator_FULLY_CONNECTED:
           op->builtin_options.AsFullyConnectedOptions()
-              ->asymmetric_quantize_inputs = true;
+              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
           break;
         case BuiltinOperator_LSTM:
           op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
-              true;
+              use_updated_hybrid_scheme;
           break;
         case BuiltinOperator_RNN:
-          op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs = true;
+          op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs =
+              use_updated_hybrid_scheme;
           break;
         case BuiltinOperator_SVDF:
           op->builtin_options.AsSVDFOptions()->asymmetric_quantize_inputs =
-              true;
+              use_updated_hybrid_scheme;
           break;
         case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
           op->builtin_options.AsUnidirectionalSequenceLSTMOptions()
-              ->asymmetric_quantize_inputs = true;
+              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
           break;
         case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
           op->builtin_options.AsSequenceRNNOptions()
-              ->asymmetric_quantize_inputs = true;
+              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
           break;
         default:
           break;
@@ -323,25 +328,27 @@
 }
 
 // Updates operator code versions for the operators with INT8 inputs.
-void UpdateInt8OperatorVersions(ModelT* model) {
+void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) {
   for (int i = 0; i < model->operator_codes.size(); ++i) {
     const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
-    if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
+    if (op_code == BuiltinOperator_RNN ||
         op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
-        op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
-        op_code == BuiltinOperator_RNN ||
         op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
         op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
+      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 3 : 2;
+    } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
+               op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
       model->operator_codes[i]->version = 3;
-    } else if (op_code == BuiltinOperator_LSTM ||
-               op_code == BuiltinOperator_SVDF) {
-      model->operator_codes[i]->version = 4;
+    } else if (op_code == BuiltinOperator_LSTM) {
+      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 3;
     } else if (op_code == BuiltinOperator_CONV_2D) {
-      model->operator_codes[i]->version = 5;
+      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2;
+    } else if (op_code == BuiltinOperator_FULLY_CONNECTED) {
+      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3;
+    } else if (op_code == BuiltinOperator_SVDF) {
+      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2;
     } else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
       model->operator_codes[i]->version = 6;
-    } else if (op_code == BuiltinOperator_FULLY_CONNECTED) {
-      model->operator_codes[i]->version = 9;
     }
   }
 }
@@ -402,7 +409,8 @@
                                  const Model* input_model,
                                  bool use_hybrid_evaluation,
                                  uint64_t weights_min_num_elements,
-                                 const CustomOpMap& custom_op_map) {
+                                 const CustomOpMap& custom_op_map,
+                                 bool use_updated_hybrid_scheme) {
   std::unique_ptr<ModelT> model;
   model.reset(input_model->UnPack());
 
@@ -415,7 +423,7 @@
       OperatorT* op = subgraph->operators[i].get();
       TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
           model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map,
-          subgraph_index));
+          subgraph_index, use_updated_hybrid_scheme));
     }
 
     for (std::pair<int32_t, TensorPerChannel> tensor_pair : tensor_map) {
@@ -456,8 +464,8 @@
         // dequantization we need to add a Dequantize op.
         bool eval_hybrid =
             use_hybrid_evaluation &&
-            IsHybridEvaluationOp(consumer_op, consumer_op_code,
-                                 custom_op_map) &&
+            IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map,
+                                 use_updated_hybrid_scheme) &&
             CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code,
                                       custom_op_map) &&
             IsQuantizedInput(consumer_op_code, custom_op_map,
@@ -516,7 +524,7 @@
   }
 
   // Update the modified operator code versions.
-  UpdateInt8OperatorVersions(model.get());
+  UpdateInt8OperatorVersions(model.get(), use_updated_hybrid_scheme);
 
   flatbuffers::Offset<Model> output_model_location =
       Model::Pack(*builder, model.get());
@@ -611,7 +619,8 @@
   // kWeightsMinSizeDefault elements are quantized.
   CustomOpMap custom_op_map;
   return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation,
-                             weights_min_num_elements, custom_op_map);
+                             weights_min_num_elements, custom_op_map,
+                             kUseUpdatedHybridSchemeDefault);
 }
 }  // namespace internal
 
@@ -620,7 +629,8 @@
                              uint64_t weights_min_num_elements) {
   CustomOpMap custom_op_map;
   return QuantizeWeightsInt8(builder, input_model, true,
-                             weights_min_num_elements, custom_op_map);
+                             weights_min_num_elements, custom_op_map,
+                             kUseUpdatedHybridSchemeDefault);
 }
 
 TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
@@ -631,7 +641,8 @@
       // kWeightsMinSizeDefault elements are quantized.
       CustomOpMap custom_op_map;
       return QuantizeWeightsInt8(builder, input_model, true,
-                                 kWeightsMinNumElementsDefault, custom_op_map);
+                                 kWeightsMinNumElementsDefault, custom_op_map,
+                                 kUseUpdatedHybridSchemeDefault);
     }
     case BufferType::QUANTIZED_FLOAT16:
       return QuantizeWeightsFloat16(builder, input_model);
@@ -643,7 +654,19 @@
                              uint64_t weights_min_num_elements,
                              const CustomOpMap& custom_op_map) {
   return QuantizeWeightsInt8(builder, input_model, true,
-                             weights_min_num_elements, custom_op_map);
+                             weights_min_num_elements, custom_op_map,
+                             kUseUpdatedHybridSchemeDefault);
+}
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+                             const Model* input_model,
+                             uint64_t weights_min_num_elements,
+                             const CustomOpMap& custom_op_map,
+                             bool use_updated_hybrid_scheme) {
+  return QuantizeWeightsInt8(builder, input_model,
+                             /*use_hybrid_evaluation=*/true,
+                             weights_min_num_elements, custom_op_map,
+                             use_updated_hybrid_scheme);
 }
 
 }  // namespace optimize
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.h b/tensorflow/lite/tools/optimize/quantize_weights.h
index 528614f..9212c9a 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.h
+++ b/tensorflow/lite/tools/optimize/quantize_weights.h
@@ -29,6 +29,13 @@
 // Supported resulting types from quantization process.
 enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 };
 
+// For internal use by conversions that require the previous behavior.
+#ifdef TFLITE_USE_PREVIOUS_HYBRID_SCHEME
+constexpr bool kUseUpdatedHybridSchemeDefault = false;
+#else
+constexpr bool kUseUpdatedHybridSchemeDefault = true;
+#endif
+
 // Quantizes input_model and populates the provided builder with the new model.
 // By default only weight tensors with more than 1024 elements will be
 // quantized.
@@ -61,6 +68,14 @@
                              uint64_t weights_min_num_elements,
                              const CustomOpMap& custom_op_map);
 
+// Same as above, but if use_updated_hybrid_scheme is false, the previous
+// quantization scheme is used.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+                             const Model* input_model,
+                             uint64_t weights_min_num_elements,
+                             const CustomOpMap& custom_op_map,
+                             bool use_updated_hybrid_scheme);
+
 namespace internal {
 // If use_hybrid_evaluation is false, will disable using hybrid eval for
 // operations that support it.
diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc
index 2f92a9a..94bff2d 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc
@@ -216,7 +216,11 @@
         EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
             << quant_tensor->name()->str();
         auto shape = GetAsVector(quant_tensor->shape());
-        EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]);
+        if (kUseUpdatedHybridSchemeDefault) {
+          EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]);
+        } else {
+          EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
+        }
       } else {
         EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
       }
@@ -533,6 +537,58 @@
   EXPECT_EQ(num_custom_ops_found, 1);
 }
 
+TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) {
+  LoadBasicModel();
+  flatbuffers::FlatBufferBuilder builder;
+  const CustomOpMap custom_op_map;
+  auto status = QuantizeWeights(&builder, model_, 0, custom_op_map, false);
+  EXPECT_EQ(status, kTfLiteOk);
+
+  const uint8_t* buffer = builder.GetBufferPointer();
+  const Model* output_model = GetModel(buffer);
+  ASSERT_TRUE(output_model);
+
+  // The graph structure should be unchanged; weights are quantized per-tensor.
+  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
+  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
+       subgraph_idx++) {
+    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
+    const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
+    ASSERT_EQ(quantized_graph->tensors()->size(),
+              float_graph->tensors()->size());
+    // Make sure the graph only has one Conv operation.
+    ASSERT_EQ(quantized_graph->operators()->size(), 1);
+    const auto op = quantized_graph->operators()->Get(0);
+    const uint32_t op_code_idx = op->opcode_index();
+    ASSERT_EQ(output_model->operator_codes()->Get(op_code_idx)->builtin_code(),
+              BuiltinOperator_CONV_2D);
+    for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
+      const auto quant_tensor = quantized_graph->tensors()->Get(i);
+      const auto float_tensor = float_graph->tensors()->Get(i);
+      EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer());
+      EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable());
+      EXPECT_EQ(GetAsVector(quant_tensor->shape()),
+                GetAsVector(float_tensor->shape()));
+      EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str());
+      // If the tensor is a weight, it should have type INT8, otherwise it
+      // should stay with type FLOAT32.
+      // If the tensor is a bias, it should have type FLOAT32.
+      if (quant_tensor->name()->str() == "conv_bias") {
+        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
+      } else if (IsModelInputOrOutput(output_model, i)) {
+        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
+      } else if (quant_tensor->buffer() != 0) {
+        EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
+            << quant_tensor->name()->str();
+        auto shape = GetAsVector(quant_tensor->shape());
+        EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
+      } else {
+        EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
+      }
+    }
+  }
+}
+
 }  // namespace
 }  // namespace optimize
 }  // namespace tflite