Accounts for added Dequantize operations in the NNAPI model operation count

Before this CL, the application would crash when using NNAPI with a target accelerator specified and a model containing Conv2D, FullyConnected or LSTM nodes with quantized weights.
The NNAPI models generated by the NNAPI delegate for such graphs can contain extra Dequantize operations. The crash was caused by the buffer passed to ANeuralNetworksModel_getSupportedOperationsForDevices being too small, since those extra Dequantize operations were not accounted for.
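
Below is a minimal sketch (not part of this patch) of the sizing contract that was violated; the helper name QuerySupportedOps and its nnapi_op_count parameter are illustrative only. The supportedOps buffer must hold one entry per operation in the finished NNAPI model, which includes any Dequantize operations the delegate inserted:

    #include <cstddef>
    #include <cstdint>
    #include <memory>

    #include "tensorflow/lite/nnapi/nnapi_implementation.h"

    // Hypothetical helper: queries per-operation accelerator support.
    // nnapi_op_count must be the operation count of the NNAPI model,
    // which may exceed the TFLite node count once Dequantize operations
    // are added for quantized weights.
    int QuerySupportedOps(const NnApi* nnapi,
                          const ANeuralNetworksModel* model,
                          const ANeuralNetworksDevice* const* devices,
                          uint32_t num_devices, size_t nnapi_op_count) {
      // Undercounting here (e.g. sizing by the TFLite node count) makes
      // this buffer too small, and NNAPI writes past its end: the crash
      // this CL fixes.
      auto supported_ops = std::make_unique<bool[]>(nnapi_op_count);
      return nnapi->ANeuralNetworksModel_getSupportedOperationsForDevices(
          model, devices, num_devices, supported_ops.get());
    }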

PiperOrigin-RevId: 314427031
Change-Id: Ie8bce2b63b3b6129942644f79c661ad0b01351ee
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index bdcb728..3eac83c 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -784,19 +784,36 @@
     return kTfLiteOk;
   }
 
+  // Adds the operation to the model and maps the operation to the originating
+  // TFLite one.
+  TfLiteStatus AddOperationToModel(ANeuralNetworksOperationType type,
+                                   uint32_t input_count, const uint32_t* inputs,
+                                   uint32_t output_count,
+                                   const uint32_t* outputs,
+                                   int lite_node_index) {
+    RETURN_TFLITE_ERROR_IF_NN_ERROR(
+        context_,
+        nnapi_->ANeuralNetworksModel_addOperation(
+            nn_model_, type, input_count, inputs, output_count, outputs),
+        "adding operation", nnapi_errno_);
+    nnapi_to_tflite_op_mapping_->push_back(lite_node_index);
+    return kTfLiteOk;
+  }
+
   // Adds a Dequantize operator and replaces the input tensor index with the
   // dequantized version. If the dequantized version of the operator already
   // exists then it is not added again.
-  TfLiteStatus AddDequantize(int nn_input_index, int lite_index,
-                             TfLiteType dequantized_type) {
-    const int ann_index = operand_mapping_->lite_index_to_ann(lite_index);
+  TfLiteStatus AddDequantize(int nn_input_index, int lite_tensor_index,
+                             TfLiteType dequantized_type, int lite_node_index) {
+    const int ann_index =
+        operand_mapping_->lite_index_to_ann(lite_tensor_index);
     int dequantized_ann_index =
         dequantize_mapping_->DequantizedAnnIndex(ann_index, dequantized_type);
 
     if (dequantized_ann_index == -1) {
       // The dequantized version does not exist yet, it has to be added: a new
       // Dequantize operation is added, yielding a new tensor.
-      const TfLiteTensor& tensor = context_->tensors[lite_index];
+      const TfLiteTensor& tensor = context_->tensors[lite_tensor_index];
       ANeuralNetworksOperandType operand_type{
           ANEURALNETWORKS_TENSOR_FLOAT32,
           static_cast<uint32_t>(tensor.dims->size),
@@ -811,12 +828,11 @@
       const uint32_t dequantize_input[1] = {static_cast<uint32_t>(ann_index)};
       const uint32_t dequantize_output[1] = {
           static_cast<uint32_t>(dequantized_ann_index)};
-      RETURN_TFLITE_ERROR_IF_NN_ERROR(
-          context_,
-          nnapi_->ANeuralNetworksModel_addOperation(
-              nn_model_, ANEURALNETWORKS_DEQUANTIZE, 1, dequantize_input, 1,
-              dequantize_output),
-          "adding operation", nnapi_errno_);
+      TF_LITE_ENSURE_OK(
+          context_, AddOperationToModel(ANEURALNETWORKS_DEQUANTIZE,
+                                        /*input_count=*/1, dequantize_input,
+                                        /*output_count=*/1, dequantize_output,
+                                        lite_node_index));
       dequantize_mapping_->Add(ann_index, dequantized_type,
                                dequantized_ann_index);
     }
@@ -832,15 +848,12 @@
   TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type,
                                     int lite_node_index) {
     // Actually add a NN API operation
-    RETURN_TFLITE_ERROR_IF_NN_ERROR(
-        context_,
-        nnapi_->ANeuralNetworksModel_addOperation(
-            nn_model_, type, static_cast<uint32_t>(augmented_inputs_.size()),
-            augmented_inputs_.data(),
-            static_cast<uint32_t>(augmented_outputs_.size()),
-            augmented_outputs_.data()),
-        "adding operation", nnapi_errno_);
-    nnapi_to_tflite_op_mapping_->push_back(lite_node_index);
+    TF_LITE_ENSURE_OK(context_,
+                      AddOperationToModel(
+                          type, static_cast<uint32_t>(augmented_inputs_.size()),
+                          augmented_inputs_.data(),
+                          static_cast<uint32_t>(augmented_outputs_.size()),
+                          augmented_outputs_.data(), lite_node_index));
     augmented_inputs_.clear();
     augmented_outputs_.clear();
     return kTfLiteOk;
@@ -3610,7 +3623,7 @@
 
 void NNAPIDelegateKernel::AddDequantizeOperatorsWhereNeeded(
     const TfLiteContext* context, int builtin_code, const TfLiteNode* node,
-    NNAPIOpBuilder* builder, int* nnapi_errno) {
+    int tflite_node_index, NNAPIOpBuilder* builder, int* nnapi_errno) {
   // Depending on the operator and the input data format, Dequantize
   // operators may need to be added. For example when the input is
   // floating-point but weights are quantized then the weights will first be
@@ -3658,7 +3671,7 @@
 
     // Insert Dequantize operator if it hasn't been done already and change
     // the node's input accordingly.
-    builder->AddDequantize(i, node->inputs->data[i], type);
+    builder->AddDequantize(i, node->inputs->data[i], type, tflite_node_index);
   }
 }
 
@@ -4018,7 +4031,7 @@
     // Dequantize operators may have to be added in case inputs are to be
     // floating-point.
     AddDequantizeOperatorsWhereNeeded(context, reg->builtin_code, node,
-                                      &builder, nnapi_errno);
+                                      node_index, &builder, nnapi_errno);
 
     builder.FinalizeAddOperation(nn_op_type, node_index);
   }
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc
index c89e7fd..9a0d13a 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc
@@ -16,6 +16,7 @@
 
 #include <algorithm>
 #include <array>
+#include <cstdint>
 #include <iterator>
 #include <memory>
 #include <numeric>
@@ -568,12 +569,9 @@
 
 // This is a model with two ops:
 //
-//  input1 ---->
-//                ADD --
-//  input2   -->        |
-//                       -->
-//                          SUB --> output
-//  input3 ---------------->
+//  input1 ----> HARD_SWISH ---->
+//                                ADD --> output
+//  input2 ---------------------->
 //
 class HardSwishAddOpsAcceleratedModel : public MultiOpModel,
                                         public AcceleratedModel {
@@ -714,6 +712,119 @@
   ASSERT_EQ(m.CountOpsExecutedByCpuKernel(), 0);
 }
 
+class QuantizedWeightsConvolutionOpModel : public SingleOpModel,
+                                           public AcceleratedModel {
+ public:
+  QuantizedWeightsConvolutionOpModel(
+      const NnApi* nnapi, std::string accelerator_name, const TensorData& input,
+      const TensorData& filter, const TensorData& output, int stride_width = 2,
+      int stride_height = 2, enum Padding padding = Padding_VALID,
+      enum ActivationFunctionType activation = ActivationFunctionType_NONE,
+      int dilation_width_factor = 1, int dilation_height_factor = 1,
+      int num_threads = -1, std::initializer_list<uint8_t> filter_data = {})
+      : SingleOpModel(), AcceleratedModel(nnapi, accelerator_name) {
+    auto* delegate = GetDelegate();
+    this->SetApplyDelegate([delegate](Interpreter* interpreter) {
+      interpreter->ModifyGraphWithDelegate(delegate);
+    });
+
+    input_ = AddInput(input);
+
+    if (filter_data.size()) {
+      filter_ = AddConstInput(filter, filter_data);
+    } else {
+      filter_ = AddInput(filter);
+    }
+
+    int bias_size = GetShape(filter_)[0];
+
+    bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+
+    output_ = AddOutput(output);
+
+    SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
+                 CreateConv2DOptions(
+                     builder_, padding, stride_width, stride_height, activation,
+                     dilation_width_factor, dilation_height_factor)
+                     .Union());
+
+    BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)},
+                     num_threads);
+  }
+
+  void SetInput(std::initializer_list<float> data) {
+    PopulateTensor(input_, data);
+  }
+
+  void SetFilter(std::initializer_list<float> data) {
+    QuantizeAndPopulate<uint8_t>(filter_, data);
+  }
+
+  void SetBias(std::initializer_list<float> data) {
+    PopulateTensor(bias_, data);
+  }
+
+  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+                               GetScale(output_), GetZeroPoint(output_));
+  }
+
+ protected:
+  int input_;
+  int filter_;
+  int bias_;
+  int output_;
+};
+
+int quantized_conv2d_model_added_nnapi_ops_count = 0;
+TEST_F(TfLiteOpMappedToMultipleNnApiOps,
+       AddedDequantizationsAreAccountedInModelOps) {
+  nnapi_mock_->ModelCreateReturns<0>();
+  nnapi_mock_->StubGetSupportedOperationsForDevicesWith(
+      [](const ANeuralNetworksModel* model,
+         const ANeuralNetworksDevice* const* devices, uint32_t numDevices,
+         bool* supportedOps) -> int {
+        std::fill(supportedOps,
+                  supportedOps + quantized_conv2d_model_added_nnapi_ops_count,
+                  true);
+        return ANEURALNETWORKS_NO_ERROR;
+      });
+  nnapi_mock_->StubAddOperationWith(
+      [](ANeuralNetworksModel* model, ANeuralNetworksOperationType type,
+         uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
+         const uint32_t* outputs) -> int {
+        ++quantized_conv2d_model_added_nnapi_ops_count;
+        return ANEURALNETWORKS_NO_ERROR;
+      });
+
+  QuantizedWeightsConvolutionOpModel m(
+      nnapi_mock_->GetNnApi(),
+      /*accelerator_name=*/"test-device", {TensorType_FLOAT32, {2, 2, 4, 1}},
+      {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, {TensorType_FLOAT32, {}});
+  m.SetInput({
+      // First batch
+      1, 1, 1, 1,  // row = 1
+      2, 2, 2, 2,  // row = 2
+      // Second batch
+      1, 2, 3, 4,  // row = 1
+      1, 2, 3, 4,  // row = 2
+  });
+  m.SetFilter({
+      1, 2, 3, 4,    // first 2x2 filter
+      -1, 1, -1, 1,  // second 2x2 filter
+      -1, -1, 1, 1,  // third 2x2 filter
+  });
+  m.SetBias({1, 2, 3});
+
+  EXPECT_EQ(m.CountOpsExecutedByCpuKernel(), 0);
+  // When delegating a quantized Conv2D, a Dequantize operation is added
+  // to the model for each quantized input.
+  // In our case one Dequantize op for the weights is expected, yielding
+  // a two-op model.
+  EXPECT_EQ(quantized_conv2d_model_added_nnapi_ops_count, 2);
+}
+
 // Model with a chain of no-op (add with zero operations)
 // interleaved with no-op custom nodes.
 class LongIdentityModel : public MultiOpModel, public AcceleratedModel {
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
index af93d96..26822c0 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
@@ -341,11 +341,9 @@
 
   std::vector<int> nnapi_to_tflite_op_mapping_;
 
-  void AddDequantizeOperatorsWhereNeeded(const TfLiteContext* context,
-                                         int builtin_code,
-                                         const TfLiteNode* node,
-                                         NNAPIOpBuilder* builder,
-                                         int* nnapi_errno);
+  void AddDequantizeOperatorsWhereNeeded(
+      const TfLiteContext* context, int builtin_code, const TfLiteNode* node,
+      int tflite_node_index, NNAPIOpBuilder* builder, int* nnapi_errno);
 
   TfLiteStatus AddOpsAndTensors(TfLiteContext* context, int* nnapi_errno);