/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"

#include <sys/mman.h>
#include <unistd.h>

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <memory>
#include <utility>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_plugin.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"

namespace tflite {
namespace {

using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

// TODO(b/110368244): figure out how to share the existing tests in kernels/ but
// with the delegation on. Also, add more unit tests to improve code coverage.

// Tuple matcher for quantized values: the two elements of `arg` may differ by
// at most one (i.e. a tolerance of 1). On mismatch the observed difference is
// written to the gMock result listener.
MATCHER(QuantizedNear, "") {
  const int off_by = abs(std::get<0>(arg) - std::get<1>(arg));
  if (off_by <= 1) {
    return true;
  }
  *result_listener << "Quantized values can be at most off by one: " << off_by;
  return false;
}

auto NnapiArrayFloatNear(const std::vector<float>& values,
                         bool relaxed = false) {
  // Uses the same tolerance as NNAPI generated tests.
  const float atol = relaxed ? 5 * 0.0009765625f : 1e-5f;
  const float rtol = relaxed ? 5 * 0.0009765625f : 5 * 1.1920928955078125e-7f;

  std::vector<Matcher<float>> matchers;
  matchers.reserve(values.size());
  for (const float& v : values) {
    const float tolerance = atol + rtol * std::abs(v);
    matchers.emplace_back(FloatNear(v, tolerance));
  }
  return ElementsAreArray(matchers);
}

class SingleOpModelWithNNAPI : public SingleOpModel {
 public:
  SingleOpModelWithNNAPI() { options_.disallow_nnapi_cpu = false; }

  explicit SingleOpModelWithNNAPI(
      const StatefulNnApiDelegate::Options& options) {
    options_ = options;
    options_.disallow_nnapi_cpu = false;
  }

  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims) {
    return interpreter_->ResizeInputTensor(tensor_index, dims);
  }

  StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }

  void SetBufferHandle(int index, TfLiteBufferHandle handle) {
    interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get());
  }

  void MarkInputTensorDataStale(int index) {
    interpreter_->tensor(index)->data_is_stale = true;
  }

  TfLiteStatus AllocateTensors() { return interpreter_->AllocateTensors(); }

  void SetTensorMaxSize(uint32_t tensor_index, size_t max_size) {
    options_.tensor_max_size_hints.emplace(tensor_index, max_size);
  }

  void ApplyNNAPIDelegate() {
    stateful_delegate_ = std::make_unique<StatefulNnApiDelegate>(options_);
    SetDelegate(stateful_delegate_.get());
    ApplyDelegate();
  }

 protected:
  void SetData(int index, TensorType type, const std::vector<float>& data) {
    switch (type) {
      case TensorType_FLOAT32:
        PopulateTensor(index, data);
        break;
      case TensorType_INT32:
        QuantizeAndPopulate<int32_t>(index, data);
        break;
      case TensorType_UINT8:
        QuantizeAndPopulate<uint8_t>(index, data);
        break;
      case TensorType_INT8:
        QuantizeAndPopulate<int8_t>(index, data);
        break;
      default:
        FAIL() << "Type not supported: " << type;
        break;
    }
  }

  void GetData(int index, TensorType type, std::vector<float>* output) {
    switch (type) {
      case TensorType_FLOAT32:
        *output = ExtractVector<float>(index);
        break;
      case TensorType_UINT8:
        *output = Dequantize<uint8_t>(ExtractVector<uint8_t>(index),
                                      GetScale(index), GetZeroPoint(index));
        break;
      default:
        FAIL() << "Type not supported: " << type;
        break;
    }
  }

  void BuildInterpreterWithNNAPI(std::vector<std::vector<int>> input_shapes,
                                 bool allow_fp32_relax_to_fp16 = false,
                                 bool apply_delegate = true) {
    // We skip those TfLite delegates that are applied by default in TfLite
    // runtime by setting 'apply_delegate' to false. Afterwards, we explicitly
    // call ApplyDelegate to apply the NNAPI delegate to meet the testing
    // purpose.
    BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16,
                     /*apply_delegate=*/false, /*allocate_and_delegate=*/true);
    if (apply_delegate) {
      ApplyNNAPIDelegate();
    }
  }

 private:
  // Stateful NNAPI delegate. This is valid only if the state-ful constructor is
  // used.
  StatefulNnApiDelegate::Options options_;
  std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_;
};

// Single ADD op model (two inputs, one output) executed through NNAPI.
// `allow_fp32_relax_to_fp16` is forwarded to BuildInterpreterWithNNAPI().
class FloatAddOpModel : public SingleOpModelWithNNAPI {
 public:
  FloatAddOpModel(const TensorData& input1, const TensorData& input2,
                  const TensorData& output,
                  ActivationFunctionType activation_type,
                  bool allow_fp32_relax_to_fp16 = false) {
    BuildAddModel(input1, input2, output, activation_type,
                  allow_fp32_relax_to_fp16);
  }

  // As above, but the NNAPI delegate is configured with `options`.
  FloatAddOpModel(const StatefulNnApiDelegate::Options& options,
                  const TensorData& input1, const TensorData& input2,
                  const TensorData& output,
                  ActivationFunctionType activation_type,
                  bool allow_fp32_relax_to_fp16 = false)
      : SingleOpModelWithNNAPI(options) {
    BuildAddModel(input1, input2, output, activation_type,
                  allow_fp32_relax_to_fp16);
  }

  int input1() { return input1_; }
  int input2() { return input2_; }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input1_;
  int input2_;
  int output_;

 private:
  // Registers the tensors and the ADD op, then builds the interpreter and
  // applies the NNAPI delegate. Shared by both constructors.
  void BuildAddModel(const TensorData& lhs, const TensorData& rhs,
                     const TensorData& result,
                     ActivationFunctionType activation,
                     bool relax_fp32_to_fp16) {
    input1_ = AddInput(lhs);
    input2_ = AddInput(rhs);
    output_ = AddOutput(result);
    const auto add_options = CreateAddOptions(builder_, activation);
    SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
                 add_options.Union());
    BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)},
                              relax_fp32_to_fp16);
  }
};

// Element-wise ADD of two 1x2x2x1 float tensors through NNAPI, no activation.
TEST(NNAPIDelegate, AddWithNoActivation) {
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
}

// Broadcasts a scalar second operand over a 1x2x2x1 tensor, no activation.
TEST(NNAPIDelegate, AddScalarWithNoActivation) {
  const TensorData tensor_operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  const TensorData scalar_operand{TensorType_FLOAT32, {}};
  FloatAddOpModel m(tensor_operand, scalar_operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.7});
  m.PopulateTensor<float>(m.input2(), {0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.3, 0.8, 0.8}));
}

// ADD with no activation while allowing NNAPI to compute FP32 in FP16
// precision. The inputs are small integers, so the FP32 and FP16 results
// coincide; the relaxed tolerance is used for the comparison.
TEST(NNAPIDelegate, AddWithNoActivationRelaxed) {
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  const TensorData result{TensorType_FLOAT32, {}};
  FloatAddOpModel m(operand, operand, result, ActivationFunctionType_NONE,
                    /*allow_fp32_relax_to_fp16=*/true);
  m.PopulateTensor<float>(m.input1(), {-2.0, -1.0, 1.0, 2.0});
  m.PopulateTensor<float>(m.input2(), {1.0, 2.0, 3.0, 4.0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  const auto expected_output =
      NnapiArrayFloatNear({-1.0, 1.0, 4.0, 6.0}, /*relaxed=*/true);
  EXPECT_THAT(m.GetOutput(), expected_output);
}

// ADD followed by a fused RELU: the negative sum is clamped to zero.
TEST(NNAPIDelegate, AddWithRelu) {
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_RELU);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({0.0, 0.4, 1.0, 1.3}));
}

// Resizes both inputs to a larger shape and back again, checking that
// reallocation succeeds and the results are correct after each resize.
TEST(NNAPIDelegate, ResizeInputTensorsWorks) {
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);

  // Resizes both inputs to `dims` and reallocates the tensors.
  auto resize_inputs = [&m](const std::vector<int>& dims) {
    EXPECT_EQ(m.ResizeInputTensor(m.input1(), dims), kTfLiteOk);
    EXPECT_EQ(m.ResizeInputTensor(m.input2(), dims), kTfLiteOk);
    EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
  };

  resize_inputs({1, 3, 2, 1});
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 0.9, 0.7});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 0.2, 0.8});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3, 1.1, 1.5}));

  resize_inputs({1, 2, 2, 1});
  m.PopulateTensor<float>(m.input1(), {0.7, 0.8, 0.9, 0.7});
  m.PopulateTensor<float>(m.input2(), {0.3, 0.5, 0.2, 0.8});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({1.0, 1.3, 1.1, 1.5}));
}

// Inputs whose shape signature has a dynamic dimension can be resized back
// and forth when the delegate allows dynamic dimensions. The execution cache
// holds a single entry, so switching sizes replaces the cached execution.
TEST(NNAPIDelegate, ResizeDynamicBatchInputTensorsWorks) {
  StatefulNnApiDelegate::Options options;
  options.allow_dynamic_dimensions = true;
  options.max_execution_cache_size = 1;

  // Float tensor description with static `shape` and shape signature
  // {1, -1, 2, 1}, i.e. dimension 1 is dynamic.
  auto dynamic_float_tensor = [](const std::vector<int>& shape) {
    return TensorData{TensorType_FLOAT32, shape, /*min=*/0.0f,
                      /*max=*/0.0f, /*scale=*/0.0f,
                      /*zero_point=*/0, /*per_channel_quantization=*/false,
                      /*per_channel_quantization_scales=*/{},
                      /*per_channel_quantization_offsets=*/{},
                      /*channel_index=*/0, /*traversal_order=*/{},
                      /*format=*/{},
                      /*block_size=*/{}, /*block_map=*/{},
                      /*shape_signature=*/{1, -1, 2, 1}};
  };

  FloatAddOpModel m(options, dynamic_float_tensor({1, 3, 2, 1}),
                    dynamic_float_tensor({1, 3, 2, 1}),
                    dynamic_float_tensor({}), ActivationFunctionType_NONE);

  // Define 2 test cases, each with a different dynamic dimension value.
  // Outputs are float sums, so they are compared with a tolerance: exact
  // comparison against double literals (e.g. -2.0f + 0.1f vs -1.9) fails.
  auto RunTestCase1 = [&m]() {
    EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 3, 2, 1}), kTfLiteOk);
    EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 3, 2, 1}), kTfLiteOk);
    EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
    m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 0.9, 0.7});
    m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 0.2, 0.8});
    ASSERT_EQ(m.Invoke(), kTfLiteOk);
    EXPECT_THAT(m.GetOutput(),
                NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3, 1.1, 1.5}));
  };
  auto RunTestCase2 = [&m]() {
    EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 2, 2, 1}), kTfLiteOk);
    EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 2, 2, 1}), kTfLiteOk);
    EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
    m.PopulateTensor<float>(m.input1(), {0.7, 0.8, 0.9, 0.7});
    m.PopulateTensor<float>(m.input2(), {0.3, 0.5, 0.2, 0.8});
    ASSERT_EQ(m.Invoke(), kTfLiteOk);
    EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({1.0, 1.3, 1.1, 1.5}));
  };

  // TODO(b/221070667): Find a way to test whether the execution has indeed been
  // reused or not.
  // This will create a new execution for case 1.
  RunTestCase1();
  // This will reuse the execution for case 1.
  RunTestCase1();
  // This will destroy case 1, and create a new execution for case 2.
  RunTestCase2();
  // This will destroy case 2, and create a new execution for case 1.
  RunTestCase1();
}

// Sanity check for the stateful NNAPI delegate with a low-power preference.
TEST(NNAPIDelegate, StatefulDelegate) {
  using Preference = StatefulNnApiDelegate::Options::ExecutionPreference;
  StatefulNnApiDelegate::Options options;
  options.execution_preference = Preference::kLowPower;

  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(options, operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
}

// Sanity check for the stateful NNAPI delegate pinned to the
// "nnapi-reference" accelerator.
TEST(NNAPIDelegate, StatefulDelegateWithAcceleratorName) {
  using Preference = StatefulNnApiDelegate::Options::ExecutionPreference;
  StatefulNnApiDelegate::Options options;
  options.execution_preference = Preference::kLowPower;
  options.accelerator_name = "nnapi-reference";

  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(options, operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
}

// An unknown accelerator name must be reported on stderr while the model
// still builds and produces correct results through the fallback path.
// Skipped when the NNAPI device-name query is unavailable.
TEST(NNAPIDelegate, StatefulDelegateWithInvalidAcceleratorName) {
  const auto* nnapi = NnApiImplementation();
  if (nnapi->ANeuralNetworksDevice_getName == nullptr) {
    GTEST_SKIP();
  }

  // Capture stderr across model construction, where the delegate is applied.
  testing::internal::CaptureStderr();
  using Preference = StatefulNnApiDelegate::Options::ExecutionPreference;
  StatefulNnApiDelegate::Options options;
  options.execution_preference = Preference::kLowPower;
  options.accelerator_name = "foo";

  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(options, operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  const std::string captured = testing::internal::GetCapturedStderr();
  EXPECT_THAT(captured,
              testing::HasSubstr(
                  "Could not find the specified NNAPI accelerator: foo"));

  // Execution should fall back to the default CPU path.
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
}

// Sanity check for the stateful NNAPI delegate with compilation caching
// enabled (cache directory plus a model token).
TEST(NNAPIDelegate, StatefulDelegateWithCompilationCaching) {
  using Preference = StatefulNnApiDelegate::Options::ExecutionPreference;
  StatefulNnApiDelegate::Options options;
  options.execution_preference = Preference::kLowPower;
  options.cache_dir = "/data/local/tmp";
  options.model_token = "NNAPIDelegate.StatefulDelegateWithCompilationCaching";

  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(options, operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
}

// Sanity check for the stateful NNAPI delegate with QoS hints: high execution
// priority and the largest representable compilation/execution timeouts.
TEST(NNAPIDelegate, StatefulDelegateWithQoS) {
  StatefulNnApiDelegate::Options options;
  options.accelerator_name = "nnapi-reference";
  options.execution_priority = ANEURALNETWORKS_PRIORITY_HIGH;
  constexpr auto kNoTimeout = UINT64_MAX;
  options.max_compilation_timeout_duration_ns = kNoTimeout;
  options.max_execution_timeout_duration_ns = kNoTimeout;
  options.max_execution_loop_timeout_duration_ns = kNoTimeout;

  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(options, operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
}

// Sanity check for the stateful NNAPI delegate using TfLiteBufferHandle:
// input1 is backed by an NNAPI memory object created from Android shared
// memory and registered with the delegate through RegisterNnapiMemory().
TEST(NNAPIDelegate, StatefulDelegateWithBufferHandles) {
  // Skip the test if Android specific functions could not be found.
  if (!NnApiImplementation()->ASharedMemory_create ||
      !NnApiImplementation()->ANeuralNetworksMemory_createFromFd) {
    GTEST_SKIP();
  }

  StatefulNnApiDelegate::Options options;
  // Allow NNAPI CPU fallback path.
  options.disallow_nnapi_cpu = false;
  options.max_execution_cache_size = 2;
  FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
                    {TensorType_FLOAT32, {1, 2, 2, 1}},
                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  auto* delegate = m.GetDelegate();
  // Create ASharedMemory and copy data into it.
  constexpr auto kInput1ByteSize = 4 * sizeof(float);
  ANeuralNetworksMemory* input1_memory = nullptr;
  int fd =
      NnApiImplementation()->ASharedMemory_create("input1", kInput1ByteSize);
  // Everything below uses `fd`, so a failure here must stop the test.
  ASSERT_GE(fd, 0);
  void* input1_memory_data =
      mmap(nullptr, kInput1ByteSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
  // mmap() signals failure with MAP_FAILED, not nullptr; writing through
  // MAP_FAILED below would crash instead of failing the test.
  ASSERT_NE(input1_memory_data, MAP_FAILED);
  float input1_data[] = {-2.0, 0.2, 0.7, 0.8};
  memcpy(input1_memory_data, input1_data, kInput1ByteSize);
  int result = NnApiImplementation()->ANeuralNetworksMemory_createFromFd(
      kInput1ByteSize, PROT_READ, fd, 0, &input1_memory);
  // NNAPI duplicates `fd`, and the mapping above remains valid after close(),
  // so the descriptor is no longer needed and would otherwise leak.
  close(fd);
  EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
  ASSERT_NE(input1_memory, nullptr);

  struct DummyMemoryContext {
    ANeuralNetworksMemory* memory_handle;
    void* memory_data;
    size_t byte_size;
  };
  DummyMemoryContext memory_context = {input1_memory, input1_memory_data,
                                       kInput1ByteSize};
  // Copies `byte_size` bytes at `memory_offset` of the mapped shared memory
  // into `tensor`, rejecting unknown memory handles and out-of-range reads.
  static StatefulNnApiDelegate::CopyToHostTensorFnPtr memory_callback =
      [](TfLiteTensor* tensor, ANeuralNetworksMemory* memory,
         size_t memory_offset, size_t byte_size,
         void* callback_context) -> TfLiteStatus {
    auto memory_context =
        reinterpret_cast<DummyMemoryContext*>(callback_context);
    if (memory != memory_context->memory_handle ||
        memory_offset + byte_size > memory_context->byte_size) {
      return kTfLiteError;
    }
    memcpy(
        tensor->data.raw,
        reinterpret_cast<uint8_t*>(memory_context->memory_data) + memory_offset,
        byte_size);
    return kTfLiteOk;
  };
  auto input1_handle = delegate->RegisterNnapiMemory(
      input1_memory, memory_callback, &memory_context);
  m.SetBufferHandle(m.input1(), input1_handle);
  m.MarkInputTensorDataStale(m.input1());
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));

  // Run the inference multiple times with the same buffer so that the execution
  // can be reused.
  for (int i = 0; i < 10; i++) {
    // Change the value a little bit.
    input1_data[0] = -2.0 + i;
    memcpy(input1_memory_data, input1_data, kInput1ByteSize);
    m.MarkInputTensorDataStale(m.input1());
    ASSERT_EQ(m.Invoke(), kTfLiteOk);
    EXPECT_THAT(m.GetOutput(),
                NnapiArrayFloatNear({-1.9f + i, 0.4f, 1.0f, 1.3f}));
  }

  // Run the inference multiple times and each time register a buffer.
  // Each will destroy the previous cache and create a new execution.
  for (int i = 0; i < 10; i++) {
    // Change the value a little bit.
    input1_data[0] = -2.0 + i;
    memcpy(input1_memory_data, input1_data, kInput1ByteSize);
    // Distinct name so the handle registered above is not shadowed.
    auto new_input1_handle = delegate->RegisterNnapiMemory(
        input1_memory, memory_callback, &memory_context);
    m.SetBufferHandle(m.input1(), new_input1_handle);
    m.MarkInputTensorDataStale(m.input1());
    m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
    ASSERT_EQ(m.Invoke(), kTfLiteOk);
    EXPECT_THAT(m.GetOutput(),
                NnapiArrayFloatNear({-1.9f + i, 0.4f, 1.0f, 1.3f}));
  }
}

// Single MUL op model (two inputs, one output) executed through NNAPI.
class FloatMulOpModel : public SingleOpModelWithNNAPI {
 public:
  FloatMulOpModel(const TensorData& input1, const TensorData& input2,
                  const TensorData& output,
                  ActivationFunctionType activation_type) {
    input1_ = AddInput(input1);
    input2_ = AddInput(input2);
    output_ = AddOutput(output);
    const auto mul_options = CreateMulOptions(builder_, activation_type);
    SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
                 mul_options.Union());
    const int lhs = input1_;
    const int rhs = input2_;
    BuildInterpreterWithNNAPI({GetShape(lhs), GetShape(rhs)});
  }

  int input1() { return input1_; }
  int input2() { return input2_; }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input1_;
  int input2_;
  int output_;
};

// Element-wise MUL of two 1x2x2x1 float tensors through NNAPI, no activation.
TEST(NNAPIDelegate, MulWithNoActivation) {
  const TensorData factor{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatMulOpModel m(factor, factor, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-0.2, 0.04, 0.21, 0.4}));
}

// Single 2-D pooling op model (`type` selects AVERAGE/MAX/L2) executed through
// NNAPI, using VALID padding, stride 2 in both directions and no activation.
class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
 public:
  FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
                      int filter_width, int filter_height,
                      const TensorData& output) {
    input_ = AddInput(input);
    output_ = AddOutput(output);

    const auto pool_options = CreatePool2DOptions(
        builder_, Padding_VALID, /*stride_w=*/2, /*stride_h=*/2, filter_width,
        filter_height, ActivationFunctionType_NONE);
    SetBuiltinOp(type, BuiltinOptions_Pool2DOptions, pool_options.Union());

    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input_;
  int output_;
};

// 2x2 average pooling over a 1x2x4x1 input yields one mean per 2x2 window.
TEST(NNAPIDelegate, AveragePoolWithNoActivation) {
  const TensorData input{TensorType_FLOAT32, {1, 2, 4, 1}};
  const TensorData output{TensorType_FLOAT32, {}};
  FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, input,
                        /*filter_width=*/2, /*filter_height=*/2, output);
  m.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({2.75, 5.75}));
}

// 2x2 max pooling over a 1x2x4x1 input yields one maximum per 2x2 window.
TEST(NNAPIDelegate, MaxPoolWithNoActivation) {
  const TensorData input{TensorType_FLOAT32, {1, 2, 4, 1}};
  const TensorData output{TensorType_FLOAT32, {}};
  FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, input,
                        /*filter_width=*/2, /*filter_height=*/2, output);
  m.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({6, 10}));
}

// 2x2 L2 pooling over a 1x2x4x1 input: sqrt of the mean of squares per window.
TEST(NNAPIDelegate, L2PoolWithNoActivation) {
  const TensorData input{TensorType_FLOAT32, {1, 2, 4, 1}};
  const TensorData output{TensorType_FLOAT32, {}};
  FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, input,
                        /*filter_width=*/2, /*filter_height=*/2, output);
  m.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({3.5, 6.5}));
}

// Single CONV_2D op model executed through NNAPI. Float inputs get a FLOAT32
// bias. Any other input type gets an INT32 bias whose scale is
// input_scale * filter_scale.
class ConvolutionOpModel : public SingleOpModelWithNNAPI {
 public:
  ConvolutionOpModel(
      const TensorData& input, const TensorData& filter,
      const TensorData& output, int stride_width = 2, int stride_height = 2,
      enum Padding padding = Padding_VALID,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE,
      int dilation_width_factor = 1, int dilation_height_factor = 1)
      : input_type_(input.type), filter_type_(filter.type) {
    input_ = AddInput(input);
    filter_ = AddInput(filter);

    // One bias value per output channel (filter dimension 0).
    int bias_size = GetShape(filter_)[0];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(filter_);
      TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
                 CreateConv2DOptions(
                     builder_, padding, stride_width, stride_height, activation,
                     dilation_width_factor, dilation_height_factor)
                     .Union());

    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  void SetFilter(std::initializer_list<float> data) {
    SetData(filter_, filter_type_, data);
  }

  // Writes the bias as float for float models, otherwise quantized to INT32.
  void SetBias(std::initializer_list<float> data) {
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  // Returns the output as floats. Float models are read directly; quantized
  // models are dequantized with the output's scale and zero point. INT8 models
  // are read as int8 (SetInput/SetFilter already accept INT8 through SetData,
  // but the output was previously always read as uint8, which fails on an
  // int8 tensor); every other quantized type is read as uint8, as before.
  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    }
    if (input_type_ == TensorType_INT8) {
      return Dequantize<int8_t>(ExtractVector<int8_t>(output_),
                                GetScale(output_), GetZeroPoint(output_));
    }
    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }

  // Returns the raw uint8 output for quantized models, or an empty vector for
  // float models.
  std::vector<uint8_t> GetQuantizedOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return {};  // Not supported.
    } else {
      return ExtractVector<uint8_t>(output_);
    }
  }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType filter_type_;
};

// In this test we set the input and output scales so that the results
// exactly match the 'non-quantized' version.
TEST(ConvolutionOpTest, SimpleTestQuantized) {
  ConvolutionOpModel model({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
                           {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
                           {TensorType_UINT8, {}, -127, 128});

  // Two batches of a 2x4 single-channel image.
  const std::initializer_list<float> image = {
      1, 1, 1, 1,  // batch 0, row 1
      2, 2, 2, 2,  // batch 0, row 2
      1, 2, 3, 4,  // batch 1, row 1
      1, 2, 3, 4,  // batch 1, row 2
  };
  // Three 2x2 filters.
  const std::initializer_list<float> filters = {
      1, 2, 3, 4,    // filter 0
      -1, 1, -1, 1,  // filter 1
      -1, -1, 1, 1,  // filter 2
  };
  model.SetInput(image);
  model.SetFilter(filters);
  model.SetBias({1, 2, 3});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      18, 2, 5,  // batch 0, left
      18, 2, 5,  // batch 0, right
      17, 4, 3,  // batch 1, left
      37, 4, 3,  // batch 1, right
  };
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected, 1e-5)));

  // The raw quantized values are checked as well.
  const std::vector<uint8_t> expected_quantized = {
      145, 129, 132,  //
      145, 129, 132,  //
      144, 131, 130,  //
      164, 131, 130,  //
  };
  EXPECT_THAT(model.GetQuantizedOutput(), ElementsAreArray(expected_quantized));
}

TEST(ConvolutionOpTest, SimpleTestQuantizedGrouped) {
  // The input has 2 channels while each filter has only 1, which makes this a
  // grouped convolution.
  ConvolutionOpModel model({TensorType_UINT8, {2, 2, 2, 2}, -63.5, 64},
                           {TensorType_UINT8, {2, 2, 2, 1}, -63.5, 64},
                           {TensorType_UINT8, {}, -127, 128});

  const std::initializer_list<float> image = {
      1, 1, 1, 1,  // batch 0, row 1
      2, 2, 2, 2,  // batch 0, row 2
      1, 2, 3, 4,  // batch 1, row 1
      1, 2, 3, 4,  // batch 1, row 2
  };
  const std::initializer_list<float> filters = {
      1, 2, 3, 4,    // filter 0
      -1, 1, -1, 1,  // filter 1
  };
  model.SetInput(image);
  model.SetFilter(filters);
  model.SetBias({1, 2});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      18, 2,  // batch 0
      23, 6,  // batch 1
  };
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected, 1e-5)));

  // The raw quantized values are checked as well.
  const std::vector<uint8_t> expected_quantized = {
      145, 129,  //
      150, 133,  //
  };
  EXPECT_THAT(model.GetQuantizedOutput(), ElementsAreArray(expected_quantized));
}

TEST(ConvolutionOpTest, FloatInputQuantizedWeights) {
  // Float activations combined with a uint8 filter; the comparison below uses
  // a loose tolerance of 0.2.
  ConvolutionOpModel hybrid_model({TensorType_FLOAT32, {2, 2, 4, 1}},
                                  {TensorType_UINT8, {3, 2, 2, 1}, 0, 64},
                                  {TensorType_FLOAT32, {}});

  const std::initializer_list<float> image = {
      1, 1, 1, 2,  // batch 0, row 1
      2, 2, 2, 1,  // batch 0, row 2
      1, 2, 3, 4,  // batch 1, row 1
      1, 2, 3, 4,  // batch 1, row 2
  };
  const std::initializer_list<float> filters = {
      1, 2, 3, 4,  // filter 0
      0, 1, 0, 1,  // filter 1
      0, 0, 1, 1,  // filter 2
  };
  hybrid_model.SetInput(image);
  hybrid_model.SetFilter(filters);
  hybrid_model.SetBias({1, 2, 3});

  ASSERT_EQ(hybrid_model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      18, 5, 7,    // batch 0, left
      16, 5, 6,    // batch 0, right
      17, 6, 6,    // batch 1, left
      37, 10, 10,  // batch 1, right
  };
  EXPECT_THAT(hybrid_model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected, 0.2)));
}

TEST(ConvolutionOpTest, NoActivation) {
  // Pure float convolution with the default (no) fused activation.
  ConvolutionOpModel float_model({TensorType_FLOAT32, {2, 2, 4, 1}},
                                 {TensorType_FLOAT32, {3, 2, 2, 1}},
                                 {TensorType_FLOAT32, {}});

  const std::initializer_list<float> image = {
      1, 1, 1, 1,  // batch 0, row 1
      2, 2, 2, 2,  // batch 0, row 2
      1, 2, 3, 4,  // batch 1, row 1
      1, 2, 3, 4,  // batch 1, row 2
  };
  const std::initializer_list<float> filters = {
      1, 2, 3, 4,    // filter 0
      -1, 1, -1, 1,  // filter 1
      -1, -1, 1, 1,  // filter 2
  };
  float_model.SetInput(image);
  float_model.SetFilter(filters);
  float_model.SetBias({1, 2, 3});

  ASSERT_EQ(float_model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      18, 2, 5,  // batch 0, left
      18, 2, 5,  // batch 0, right
      17, 4, 3,  // batch 1, left
      37, 4, 3,  // batch 1, right
  };
  EXPECT_THAT(float_model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) {
  // These ranges give output_multiplier = 1.0118.
  ConvolutionOpModel quant_op({TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128},
                              {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128},
                              {TensorType_UINT8, {}, -127, 128});
  ConvolutionOpModel float_op({TensorType_FLOAT32, {2, 2, 4, 1}},
                              {TensorType_FLOAT32, {3, 2, 2, 1}},
                              {TensorType_FLOAT32, {}});
  std::initializer_list<float> input = {
      1, 1, 1, 1,  // batch 0, row 1
      2, 2, 2, 2,  // batch 0, row 2
      1, 2, 3, 4,  // batch 1, row 1
      1, 2, 3, 4,  // batch 1, row 2
  };
  std::initializer_list<float> filter = {
      1,  2,  3,  4,  // filter 0
      -1, 1,  -1, 1,  // filter 1
      -1, -1, 1,  1,  // filter 2
  };
  std::initializer_list<float> bias = {1, 2, 3};

  // Feed identical data to the quantized model and then to the float
  // reference, running each one right after it is populated.
  for (ConvolutionOpModel* op : {&quant_op, &float_op}) {
    op->SetInput(input);
    op->SetFilter(filter);
    op->SetBias(bias);
    ASSERT_EQ(op->Invoke(), kTfLiteOk);
  }

  // The dequantized result must stay within 1 of the float reference.
  EXPECT_THAT(quant_op.GetOutput(),
              ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}

TEST(ConvolutionOpTest, SimpleTestFloatWithDilation) {
  const int depth = 1;
  const int image_width = 9;
  const int image_height = 9;
  const int image_batch_count = 1;
  const int filter_size = 3;
  const int filter_count = 1;
  const int stride_width = 1;
  const int stride_height = 1;
  const int dilation_width_factor = 3;
  const int dilation_height_factor = 3;
  const Padding padding = Padding_VALID;
  ConvolutionOpModel m(
      {TensorType_FLOAT32,
       {image_batch_count, image_height, image_width, depth}},
      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
      {TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
      ActivationFunctionType_NONE, dilation_width_factor,
      dilation_height_factor);

  // The image matrix is:
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // clang-format off
  m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 1, 1, 1, 0, 0, 0,
              0, 0, 0, 1, 1, 1, 0, 0, 0,
              0, 0, 0, 1, 1, 1, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0});
  // clang-format on
  // The filter matrix is:
  // | 1 | 2 | 3 |
  // | 4 | 5 | 6 |
  // | 7 | 8 | 9 |
  m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
  // Zero bias for this test.
  m.SetBias({0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // With a dilation rate of 3 the 3x3 filter covers a 7x7 window, so with
  // VALID padding the 9x9 image yields a 3x3 output (instead of 7x7 without
  // dilation). At every output position only the centre filter tap (weight 5)
  // lands on the block of ones, so the output is all 5s:
  // | 5 | 5 | 5 |
  // | 5 | 5 | 5 |
  // | 5 | 5 | 5 |
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}

// ConvolutionOpModel variant whose setters always quantize (uint8 input and
// filter, int32 bias) and whose GetOutput returns the raw uint8 output.
class QuantizedConvolutionOpModel : public ConvolutionOpModel {
 public:
  using ConvolutionOpModel::ConvolutionOpModel;

  // Quantizes `data` to uint8 and stores it in the input tensor.
  void SetInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(input_, data);
  }

  // Quantizes `data` to uint8 and stores it in the filter tensor.
  void SetFilter(std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(filter_, data);
  }

  // Quantizes `data` to int32 and stores it in the bias tensor.
  void SetBias(std::initializer_list<float> data) {
    QuantizeAndPopulate<int32_t>(bias_, data);
  }

  // Raw uint8 output values.
  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }

  // Output dequantized with the output tensor's scale and zero point.
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<uint8_t>(GetOutput(), GetScale(output_),
                               GetZeroPoint(output_));
  }
};

TEST(ConvolutionOpTest, SimpleTestQuantizedWithDilation) {
  const int depth = 1;
  const int image_width = 9;
  const int image_height = 9;
  const int image_batch_count = 1;
  const int filter_size = 3;
  const int filter_count = 1;
  const int stride_width = 1;
  const int stride_height = 1;
  const int dilation_width_factor = 3;
  const int dilation_height_factor = 3;
  const Padding padding = Padding_VALID;
  ConvolutionOpModel m({TensorType_UINT8,
                        {image_batch_count, image_height, image_width, depth},
                        0,
                        127.5},
                       {TensorType_UINT8,
                        {depth, filter_size, filter_size, filter_count},
                        0,
                        127.5},
                       {TensorType_UINT8, {}, 0, 255}, stride_width,
                       stride_height, padding, ActivationFunctionType_NONE,
                       dilation_width_factor, dilation_height_factor);

  // The image matrix is:
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // clang-format off
  m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 1, 1, 1, 0, 0, 0,
              0, 0, 0, 1, 1, 1, 0, 0, 0,
              0, 0, 0, 1, 1, 1, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0});
  // clang-format on
  // The filter matrix is:
  // | 1 | 2 | 3 |
  // | 4 | 5 | 6 |
  // | 7 | 8 | 9 |
  m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
  // Zero bias for this test.
  m.SetBias({0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // With a dilation rate of 3 the 3x3 filter covers a 7x7 window, so with
  // VALID padding the 9x9 image yields a 3x3 output (instead of 7x7 without
  // dilation). At every output position only the centre filter tap (weight 5)
  // lands on the block of ones, so the output is all 5s:
  // | 5 | 5 | 5 |
  // | 5 | 5 | 5 |
  // | 5 | 5 | 5 |
  EXPECT_THAT(m.GetQuantizedOutput(),
              ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}

// Test model for a single CONV_2D op with a constant, per-channel quantized
// filter and a constant int32 bias, built with the NNAPI delegate. The bias
// gets one scale per filter channel (input scale * filter channel scale) and
// all-zero offsets.
class PerChannelQuantizedConvolutionWithConstantFilterOpModel
    : public SingleOpModelWithNNAPI {
 public:
  PerChannelQuantizedConvolutionWithConstantFilterOpModel(
      const TensorData& input, const TensorData& filter,
      std::initializer_list<int8_t> filter_data,
      std::initializer_list<int32_t> bias_data, const TensorData& output,
      int stride_width = 2, int stride_height = 2,
      enum Padding padding = Padding_VALID,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE,
      int dilation_width_factor = 1, int dilation_height_factor = 1)
      : input_type_(input.type), filter_type_(filter.type) {
    // This model only supports per-channel quantized filters.
    CHECK(filter.per_channel_quantization);
    input_ = AddInput(input);
    filter_ = AddConstInput(filter, filter_data);

    // One bias entry per output channel (first filter dimension).
    const int bias_size = GetShape(filter_)[0];
    // Bias scale for channel i is input_scale * filter_scale[i]; all bias
    // offsets are zero.
    const int num_channels = filter.per_channel_quantization_scales.size();
    const std::vector<int64_t> bias_offsets(num_channels, 0);
    std::vector<float> bias_scales(num_channels);
    for (int i = 0; i < num_channels; i++) {
      bias_scales[i] = input.scale * filter.per_channel_quantization_scales[i];
    }
    const TensorData bias{TensorType_INT32,
                          {bias_size},
                          /*min=*/0,
                          /*max=*/0,
                          /*scale=*/0,
                          /*zero_point=*/0,
                          /*per_channel_quantization=*/true,
                          /*per_channel_quantization_scales=*/bias_scales,
                          /*per_channel_quantization_offsets=*/bias_offsets,
                          /*channel_index=*/0};
    bias_ = AddConstInput(bias, bias_data);

    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
                 CreateConv2DOptions(
                     builder_, padding, stride_width, stride_height, activation,
                     dilation_width_factor, dilation_height_factor)
                     .Union());

    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  // Quantizes `data` to int8 and stores it in the input tensor.
  void SetInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<int8_t>(input_, data);
  }

  // Raw int8 output values.
  std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType filter_type_;
};

TEST(ConvolutionOpTest, SimplePerChannelTest) {
  // Filter is [output_channel, y, x, input_channel] = [2, 2, 2, 2] with one
  // quantization scale per output channel.
  PerChannelQuantizedConvolutionWithConstantFilterOpModel model(
      {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1},
      {TensorType_INT8,
       {2, 2, 2, 2},
       /*min=*/0,
       /*max=*/0,
       /*scale=*/0,
       /*zero_point=*/0,
       /*per_channel_quantization=*/true,
       /*per_channel_quantization_scales=*/{1, 2},
       /*per_channel_quantization_offsets=*/{0, 0},
       /*channel_index=*/0},
      /*filter_data=*/
      {
          1, 2,  // out channel = 0, y = 0, x = 0
          3, 4,  // out channel = 0, y = 0, x = 1
          3, 4,  // out channel = 0, y = 1, x = 0
          5, 6,  // out channel = 0, y = 1, x = 1
          4, 4,  // out channel = 1, y = 0, x = 0
          3, 3,  // out channel = 1, y = 0, x = 1
          2, 2,  // out channel = 1, y = 1, x = 0
          1, 1,  // out channel = 1, y = 1, x = 1
      },
      /*bias_data=*/{6, -2}, {TensorType_INT8, {}, -63.5, 64, 0.5, -1},
      /*stride_width=*/1, /*stride_height=*/1);

  // Input is [batch, y, x, input_channel] = [1, 2, 3, 2].
  const std::initializer_list<float> image = {
      3, 2,    // batch = 0, y = 0, x = 0
      1, -1,   // batch = 0, y = 0, x = 1
      -2, -3,  // batch = 0, y = 0, x = 2
      4, 3,    // batch = 0, y = 1, x = 0
      2, -2,   // batch = 0, y = 1, x = 1
      -3, -4,  // batch = 0, y = 1, x = 2
  };
  model.SetInput(image);

  // Output is [batch, y, x, output_channel] = [1, 1, 2, 2]; each quantized
  // value may be off by one.
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              testing::Pointwise(QuantizedNear(), {61, 127, -115, -93}));
}

class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
 public:
  DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
                              const TensorData& output)
      : input_type_(input.type) {
    input_ = AddInput(input);
    filter_ = AddInput(filter);

    int bias_size = GetShape(filter_)[3];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(filter_);
      TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    int input_depth = GetShape(input_)[3];
    int output_depth = GetShape(filter_)[3];
    int depth_mul = output_depth / input_depth;

    SetBuiltinOp(
        BuiltinOperator_DEPTHWISE_CONV_2D,
        BuiltinOptions_DepthwiseConv2DOptions,
        CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
                                     ActivationFunctionType_NONE)
            .Union());

    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  void SetFilter(std::initializer_list<float> data) {
    SetData(filter_, input_type_, data);
  }

  void SetBias(std::initializer_list<float> data) {
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    } else {
      return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                 GetScale(output_), GetZeroPoint(output_));
    }
  }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
};

TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
  // 2 input channels expand to 4 output channels (depth multiplier 2).
  DepthwiseConvolutionOpModel model({TensorType_FLOAT32, {1, 3, 2, 2}},
                                    {TensorType_FLOAT32, {1, 2, 2, 4}},
                                    {TensorType_FLOAT32, {}});

  const std::initializer_list<float> image = {
      1, 2, 7, 8,    // column 1
      3, 4, 9, 10,   // column 2
      5, 6, 11, 12,  // column 3
  };
  const std::initializer_list<float> filter = {
      1, 2, 3, 4,        //
      -9, 10, -11, 12,   //
      5, 6, 7, 8,        //
      13, -14, 15, -16,  //
  };
  model.SetInput(image);
  model.SetFilter(filter);
  model.SetBias({1, 2, 3, 4});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      71, -34, 99, -20,  //
      91, -26, 127, -4,  //
  };
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(QuantizedDepthwiseConv2DTest, FilterMultiplierGreaterThan1) {
  DepthwiseConvolutionOpModel quant_op(
      {TensorType_UINT8, {1, 3, 2, 2}, -128.5, 128},
      {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
      {TensorType_UINT8, {}, -127, 128});
  DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}},
                                       {TensorType_FLOAT32, {1, 2, 2, 4}},
                                       {TensorType_FLOAT32, {}});

  std::initializer_list<float> input = {
      1, 2, 7,  8,   // column 1
      3, 4, 9,  10,  // column 2
      5, 6, 11, 12,  // column 3
  };
  std::initializer_list<float> filter = {
      1,  2,   3,   4,    //
      -9, 10,  -11, 12,   //
      5,  6,   7,   8,    //
      13, -14, 15,  -16,  //
  };
  std::initializer_list<float> bias = {1, 2, 3, 4};

  // Populate and run the quantized model first, then the float reference.
  for (DepthwiseConvolutionOpModel* op : {&quant_op, &float_op}) {
    op->SetInput(input);
    op->SetFilter(filter);
    op->SetBias(bias);
    ASSERT_EQ(op->Invoke(), kTfLiteOk);
  }

  // The dequantized result must stay within 1 of the float reference.
  EXPECT_THAT(quant_op.GetOutput(),
              ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}

class FullyConnectedOpModel : public SingleOpModelWithNNAPI {
 public:
  FullyConnectedOpModel(
      const TensorData& input, const TensorData& weights,
      const TensorData& output,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE)
      : input_type_(input.type), weights_type_(weights.type) {
    input_ = AddInput(input);
    weights_ = AddInput(weights);

    const int units = weights.shape[0];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {units}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(weights_);
      TensorData bias{TensorType_INT32, {units}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
                 BuiltinOptions_FullyConnectedOptions,
                 CreateFullyConnectedOptions(builder_, activation).Union());
    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(weights_), GetShape(bias_)});
  }

  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  void SetWeights(std::initializer_list<float> data) {
    SetData(weights_, weights_type_, data);
  }

  void SetBias(std::initializer_list<float> data) {
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    } else {
      return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                 GetScale(output_), GetZeroPoint(output_));
    }
  }

 protected:
  int input_;
  int weights_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType weights_type_;
};

TEST(FullyConnectedOpTest, SimpleTest) {
  FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
                          /*weights=*/{TensorType_FLOAT32, {3, 10}},
                          /*output=*/{TensorType_FLOAT32});
  // Three identical weight rows, one per output unit.
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Each batch's dot product (23 and 57) plus the per-unit bias.
  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}

TEST(FullyConnectedOpTest, FloatInputQuantizedWeights) {
  // Float input with uint8 weights; compared with a tolerance of 1.3.
  FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
                          /*weights=*/{TensorType_UINT8, {3, 10}, 0, 64},
                          /*output=*/{TensorType_FLOAT32});
  // Three identical weight rows, one per output unit.
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}, 1.3)));
}

TEST(FullyConnectedOpTest, QuantizedOutputMultiplierGreaterThan1) {
  // These ranges give real_multiplier = 2.
  FullyConnectedOpModel model(
      /*input=*/{TensorType_UINT8, {2, 10}, -127, 128},
      /*weights=*/{TensorType_UINT8, {3, 10}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  const std::initializer_list<float> weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  const std::initializer_list<float> batches = {
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  };
  model.SetWeights(weights);
  model.SetBias({1, 2, 3});
  model.SetInput(batches);

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      24, 25, 26,  // batch 0
      58, 59, 60,  // batch 1
  };
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}

// Test model for a single SOFTMAX op with the given beta, built with the NNAPI
// delegate. The output tensor reuses the input's TensorData description.
class SoftmaxOpModel : public SingleOpModelWithNNAPI {
 public:
  SoftmaxOpModel(const TensorData& input, float beta) {
    input_ = AddInput(input);
    output_ = AddOutput(input);
    const auto softmax_options = CreateSoftmaxOptions(builder_, beta).Union();
    SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
                 softmax_options);
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  // Writes the whole input tensor from a list of floats.
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }

  // Writes the floats in [begin, end) into the input tensor starting at
  // element `offset`.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor<float>(input_, offset, begin, end);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 private:
  int input_;
  int output_;
};

TEST(SoftmaxOpTest, SimpleTest) {
  SoftmaxOpModel model({TensorType_FLOAT32, {2, 5}}, /*beta=*/1.0);
  const std::initializer_list<float> logits = {
      1.0, 2.0, 3.0, 4.0, 5.0,       // b = 0
      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 1
  };
  model.SetInput(logits);

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // The second batch holds negated logits, so its probabilities are the
  // first batch's in reverse order.
  const std::vector<float> expected = {
      0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
      0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231};
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(SoftmaxOpTest, Beta2) {
  SoftmaxOpModel model({TensorType_FLOAT32, {1, 5}}, /*beta=*/2.0);
  const std::initializer_list<float> logits = {
      1.0, 2.0, 3.0, 4.0, 5.0,  // b = 0
  };
  model.SetInput(logits);

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0.000290076, 0.002143387, 0.015837606,
                                       0.117024957, 0.864703974};
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(SoftmaxOpTest, 3dInput) {
  SoftmaxOpModel model({TensorType_FLOAT32, {2, 2, 5}}, /*beta=*/1.0);
  const std::initializer_list<float> logits = {
      1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 0
      5.0,  1.0,  2.0,  3.0,  4.0,   // b = 1
      -5.0, -1.0, -2.0, -3.0, -4.0,  // b = 1
  };
  model.SetInput(logits);

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Softmax is applied independently to each row of 5 values.
  const std::vector<float> expected = {
      0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
      0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231,
      0.636408647, 0.011656231, 0.031684921, 0.086128544, 0.234121657,
      0.011656231, 0.636408647, 0.234121657, 0.086128544, 0.031684921};
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(SoftmaxOpTest, 4dInput) {
  SoftmaxOpModel model({TensorType_FLOAT32, {2, 2, 1, 5}}, /*beta=*/1.0);
  const std::initializer_list<float> logits = {
      1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 0
      5.0,  1.0,  2.0,  3.0,  4.0,   // b = 1
      -5.0, -1.0, -2.0, -3.0, -4.0,  // b = 1
  };
  model.SetInput(logits);

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Softmax is applied independently to each row of 5 values.
  const std::vector<float> expected = {
      0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
      0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231,
      0.636408647, 0.011656231, 0.031684921, 0.086128544, 0.234121657,
      0.011656231, 0.636408647, 0.234121657, 0.086128544, 0.031684921};
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

// Test model for a single float RESHAPE op built with the NNAPI delegate. The
// target shape is supplied both as a constant int32 input tensor and in the
// op's builtin options.
class ReshapeOpModel : public SingleOpModelWithNNAPI {
 public:
  ReshapeOpModel(std::initializer_list<int> input_shape,
                 std::initializer_list<int> new_shape) {
    const int new_rank = static_cast<int>(new_shape.size());
    input_ = AddInput(TensorType_FLOAT32);
    new_shape_ = AddConstInput<int>(TensorType_INT32, new_shape, {new_rank});
    output_ = AddOutput(TensorType_FLOAT32);
    const auto reshape_options =
        CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
            .Union();
    SetBuiltinOp(BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
                 reshape_options);
    BuildInterpreterWithNNAPI({input_shape, {new_rank}});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int new_shape_;
  int output_;
};

TEST(NNAPIDelegate, ReshapeSimpleTest) {
  // Reshaping {1, 2, 4, 1} to {2, 2, 2} keeps the element order.
  ReshapeOpModel model({1, 2, 4, 1}, {2, 2, 2});
  const std::vector<float> values = {1, 2, 3, 4, 5, 6, 7, 8};
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(values));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}

// Test model for a single float SQUEEZE op built with the NNAPI delegate.
// `axis` lists the dimensions to squeeze and is passed to the op's builtin
// options.
class SqueezeOpModel : public SingleOpModelWithNNAPI {
 public:
  SqueezeOpModel(const TensorData& input, const TensorData& output,
                 std::initializer_list<int> axis) {
    input_ = AddInput(input);
    output_ = AddOutput(output);
    SetBuiltinOp(
        BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
        CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
            .Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int output_;
};

// TODO(b/215935381): Enable after resolving issues with flakiness.
TEST(NNAPIDelegate, DISABLED_SqueezeSimpleTest) {
  // Squeezing with no axes removes every size-1 dimension: {1, 24, 1} -> {24}.
  std::initializer_list<float> data = {
      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  SqueezeOpModel model({TensorType_FLOAT32, {1, 24, 1}},
                       {TensorType_FLOAT32, {24}}, {});
  model.SetInput(data);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({24}));
  // Squeeze only changes the shape, so the values come back unchanged.
  const std::vector<float> expected(data);
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(NNAPIDelegate, SqueezeWithAxisTest) {
  // Squeezing only axis 2 turns {1, 24, 1} into {1, 24}.
  std::initializer_list<float> data = {
      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  SqueezeOpModel model({TensorType_FLOAT32, {1, 24, 1}},
                       {TensorType_FLOAT32, {24}}, {2});
  model.SetInput(data);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 24}));
  // Squeeze only changes the shape, so the values come back unchanged.
  const std::vector<float> expected(data);
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

// Test model for a single float L2_NORMALIZATION op with the given fused
// activation, built with the NNAPI delegate.
class L2NormOpModel : public SingleOpModelWithNNAPI {
 public:
  L2NormOpModel(const TensorData& input, const TensorData& output,
                ActivationFunctionType activation_type) {
    input_ = AddInput(input);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
                 CreateL2NormOptions(builder_, activation_type).Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int output_;
};

TEST(NNAPIDelegate, L2NormSimpleTest) {
  // The squares of these values sum to 4, so the L2 norm is 2 and the
  // normalized output is the input halved.
  std::initializer_list<float> data = {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1};
  L2NormOpModel model({TensorType_FLOAT32, {1, 1, 1, 6}},
                      {TensorType_FLOAT32, {1, 1, 1, 6}},
                      ActivationFunctionType_NONE);
  model.SetInput(data);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 6}));
  const std::vector<float> expected = {-0.55, 0.3, 0.35, 0.6, -0.35, 0.05};
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

class TransposeSimpleModel : public SingleOpModelWithNNAPI {
 public:
  TransposeSimpleModel(std::initializer_list<int> input_shape,
                       std::initializer_list<int> perm_shape,
                       std::initializer_list<int> perm) {
    input_ = AddInput(TensorType_FLOAT32);
    perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
                 CreateTransposeOptions(builder_).Union());
    BuildInterpreterWithNNAPI({input_shape, perm_shape});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int perm_;
  int output_;
};

// Permutation {2, 0, 1} moves the last axis to the front: a [2, 3, 4] input
// becomes [4, 2, 3] with output[k][i][j] == input[i][j][k].
TEST(NNAPIDelegate, TransposeSimpleTest) {
  TransposeSimpleModel model(/*input_shape=*/{2, 3, 4}, /*perm_shape=*/{3},
                             /*perm=*/{2, 0, 1});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  const std::vector<float> expected = {
      0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,  //
      2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23};
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

// Common base for single-input / single-output elementwise op models. Derived
// classes are expected to assign `input_` and `output_` while building the
// graph; this class only exposes them.
class ElementwiseOpBaseModel : public SingleOpModelWithNNAPI {
 public:
  // Id of the op's input tensor (as returned by AddInput in the subclass).
  int input() const { return input_; }
  // Id of the op's output tensor (as returned by AddOutput in the subclass).
  int output() const { return output_; }

 protected:
  int input_;
  int output_;
};

// Float32 elementwise model for a builtin `op` that takes no options.
class ElementwiseOpFloatModel : public ElementwiseOpBaseModel {
 public:
  // `input_shape` is handed to BuildInterpreterWithNNAPI for the single input.
  ElementwiseOpFloatModel(BuiltinOperator op,
                          std::initializer_list<int> input_shape) {
    input_ = AddInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(op, BuiltinOptions_NONE, /*builtin_options=*/0);
    BuildInterpreterWithNNAPI({input_shape});
  }
};

// ABS flips the sign of negative elements and keeps the shape.
TEST(Elementwise, Abs) {
  ElementwiseOpFloatModel model(BuiltinOperator_ABS, {1, 2, 4, 1});
  model.PopulateTensor<float>(model.input(), {0.f, -6.2f, 2.f, 4.f,  //
                                              3.f, -2.f, 10.f, 1.f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0.f, 6.2f, 2.f, 4.f,  //
                                       3.f, 2.f, 10.f, 1.f};
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear(expected));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 2, 4, 1}));
}

// EXP of {1, 0, -1} maps to {e, 1, 1/e}.
TEST(Elementwise, Exp) {
  ElementwiseOpFloatModel model(BuiltinOperator_EXP, {3, 1, 2});
  model.PopulateTensor<float>(model.input(), {1.0, 0.0, -1.0, 1.0, 1.0, -1.0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {2.71828, 1,       0.367879,
                                       2.71828, 2.71828, 0.367879};
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear(expected));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({3, 1, 2}));
}

// Natural log: log(1) == 0 and log(pi) ~= 1.14473.
TEST(Elementwise, Log) {
  ElementwiseOpFloatModel model(BuiltinOperator_LOG, {1, 1, 4, 1});
  model.PopulateTensor<float>(model.input(), {1, 3.1415926, 1, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0, 1.14473, 0, 0};
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear(expected));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 1, 4, 1}));
}

// RSQRT computes 1 / sqrt(x) elementwise.
TEST(Elementwise, Rsqrt) {
  ElementwiseOpFloatModel model(BuiltinOperator_RSQRT, {1, 1, 4, 1});
  model.PopulateTensor<float>(model.input(), {1, 2, 4, 9});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {1, 0.7071, 0.5, 0.33333};
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear(expected));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 1, 4, 1}));
}

// SIN at 0 and +/-pi is (approximately) zero; sin(1) ~= 0.84147.
TEST(Elementwise, Sin) {
  ElementwiseOpFloatModel model(BuiltinOperator_SIN, {1, 1, 4, 1});
  model.PopulateTensor<float>(model.input(), {0, 3.1415926, -3.1415926, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0, 0, 0, 0.84147};
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear(expected));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 1, 4, 1}));
}

// SQRT of {0, 1, 2, 4}.
TEST(Elementwise, Sqrt) {
  ElementwiseOpFloatModel model(BuiltinOperator_SQRT, {1, 1, 4, 1});
  model.PopulateTensor<float>(model.input(), {0, 1, 2, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0, 1, 1.41421, 2};
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear(expected));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 1, 4, 1}));
}

class FloatSubOpModel : public SingleOpModelWithNNAPI {
 public:
  FloatSubOpModel(const TensorData& input1, const TensorData& input2,
                  const TensorData& output,
                  ActivationFunctionType activation_type) {
    input1_ = AddInput(input1);
    input2_ = AddInput(input2);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions,
                 CreateMulOptions(builder_, activation_type).Union());
    BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)});
  }

  int input1() { return input1_; }
  int input2() { return input2_; }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input1_;
  int input2_;
  int output_;
};

// Elementwise input1 - input2 with no fused activation.
TEST(NNAPIDelegate, SubWithNoActivation) {
  FloatSubOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.7, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {-2.1, 0.0, 0.4, 0.3};
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

class FloatDivOpModel : public SingleOpModelWithNNAPI {
 public:
  FloatDivOpModel(const TensorData& input1, const TensorData& input2,
                  const TensorData& output,
                  ActivationFunctionType activation_type) {
    input1_ = AddInput(input1);
    input2_ = AddInput(input2);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_DIV, BuiltinOptions_DivOptions,
                 CreateMulOptions(builder_, activation_type).Union());
    BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)});
  }

  int input1() { return input1_; }
  int input2() { return input2_; }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input1_;
  int input2_;
  int output_;
};

// Elementwise input1 / input2 with no fused activation.
TEST(NNAPIDelegate, DivWithNoActivation) {
  FloatDivOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.8, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.4, 0.2});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {-20, 1, 2, 4};
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

// Single-op CONCATENATION model. With the 3-argument constructor, all
// `num_inputs` inputs share `input_template`, and the output takes the
// template's type and min/max with an empty shape. Derived classes add typed
// accessors.
class BaseConcatenationOpModel : public SingleOpModelWithNNAPI {
 public:
  // Builds nothing; lets a derived class construct its own graph.
  BaseConcatenationOpModel() {}
  BaseConcatenationOpModel(const TensorData& input_template, int axis,
                           int num_inputs) {
    std::vector<std::vector<int>> all_input_shapes;
    for (int input_index = 0; input_index < num_inputs; ++input_index) {
      // Input tensor ids are not stored; accessors in derived classes address
      // inputs by index.
      AddInput(input_template);
      all_input_shapes.push_back(input_template.shape);
    }
    output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
                         input_template.max});
    const auto concat_options = CreateConcatenationOptions(
        builder_, axis, ActivationFunctionType_NONE);
    SetBuiltinOp(BuiltinOperator_CONCATENATION,
                 BuiltinOptions_ConcatenationOptions, concat_options.Union());
    BuildInterpreterWithNNAPI(all_input_shapes);
  }

 protected:
  int output_;
};

// Float accessors on top of BaseConcatenationOpModel.
class ConcatenationOpModel : public BaseConcatenationOpModel {
 public:
  using BaseConcatenationOpModel::BaseConcatenationOpModel;

  // Writes `data` into tensor `index`. This presumably relies on the inputs
  // being the first tensors added (input i == tensor i) — TODO confirm.
  void SetInput(int index, std::initializer_list<float> data) {
    PopulateTensor(index, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};

// Concatenating a single input is an identity.
TEST(NNAPIDelegate, ConcatenationThreeDimensionalOneInput) {
  ConcatenationOpModel single_input({TensorType_FLOAT32, {2, 1, 2}},
                                    /*axis=*/1, /*num_inputs=*/1);
  single_input.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  ASSERT_EQ(single_input.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {1, 3, 4, 7};
  EXPECT_THAT(single_input.GetOutput(), NnapiArrayFloatNear(expected));
}

// Four [2, 1, 2] inputs joined on the innermost axis: each output row
// interleaves the corresponding row of every input.
TEST(NNAPIDelegate, ConcatenationFourInputs) {
  ConcatenationOpModel concat({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2,
                              /*num_inputs=*/4);
  concat.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  concat.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  concat.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  concat.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(concat.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {
      1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
      4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
  };
  EXPECT_THAT(concat.GetOutput(), NnapiArrayFloatNear(expected));
}

class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
 public:
  using BaseConcatenationOpModel::BaseConcatenationOpModel;
  QuantizedConcatenationOpModel(const std::vector<TensorData>& input_template,
                                int axis, int num_inputs,
                                const TensorData& output_template) {
    std::vector<std::vector<int>> all_input_shapes;
    CHECK_EQ(input_template.size(), num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      all_input_shapes.push_back(input_template[i].shape);
      AddInput(input_template[i]);
    }
    output_ = AddOutput({output_template.type, /*shape=*/{},
                         output_template.min, output_template.max});
    SetBuiltinOp(
        BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
        CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
            .Union());
    BuildInterpreterWithNNAPI(all_input_shapes);
  }
  void SetInput(int index, std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(index, data);
  }
  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }
};

// All inputs and the output share the range [-12.7, 12.8], so concatenation
// only rearranges the quantized bytes.
TEST(NNAPIDelegate, ConcatenationFourInputsQuantized) {
  QuantizedConcatenationOpModel concat(
      {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8},
      /*axis=*/2,
      /*num_inputs=*/4);

  concat.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  concat.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  concat.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  concat.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(concat.Invoke(), kTfLiteOk);
  const std::vector<float> expected_values = {
      1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
      4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
  };
  EXPECT_THAT(concat.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(concat.GetOutput(), ElementsAreArray({
                                      137, 157, 138, 158, 139, 159, 140, 160,  //
                                      167, 197, 168, 198, 169, 199, 170, 200,  //
                                  }));
}

// Each input has its own quantization range; the op must requantize into the
// output range [-12.7, 12.8], giving the same bytes as the uniform-range test.
TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) {
  QuantizedConcatenationOpModel concat(
      {{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
       {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
      /*axis=*/2, /*num_inputs=*/4,
      {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8});

  concat.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  concat.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  concat.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  concat.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(concat.Invoke(), kTfLiteOk);
  const std::vector<float> expected_values = {
      1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
      4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
  };
  EXPECT_THAT(concat.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(concat.GetOutput(), ElementsAreArray({
                                      137, 157, 138, 158, 139, 159, 140, 160,  //
                                      167, 197, 168, 198, 169, 199, 170, 200,  //
                                  }));
}

// Single-op DEQUANTIZE model: an `inputType` input quantized over [min, max]
// and a float32 output with the same `shape`.
class DequantizeOpModel : public SingleOpModelWithNNAPI {
 public:
  DequantizeOpModel(TensorType inputType, std::initializer_list<int> shape,
                    float min, float max) {
    input_ = AddInput({inputType, shape, min, max});
    output_ = AddOutput({TensorType_FLOAT32, shape});
    const auto dequantize_options = CreateDequantizeOptions(builder_);
    SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
                 dequantize_options.Union());

    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  // Writes already-quantized values of type T directly into the input.
  template <typename T>
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 private:
  int input_;
  int output_;
};

// [-63.5, 64] -> scale=0.5, zero_point=127 for UINT8, so q maps to
// (q - 127) * 0.5.
TEST(NNAPIDelegate, DequantizeFourDimensionalUint8) {
  DequantizeOpModel model(TensorType_UINT8, {2, 5}, -63.5, 64);

  model.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {-63.5, -63, -62.5, -62, -61.5,
                                       62,    62.5, 63,   63.5, 64};
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}

// [-64, 63.5] -> scale=0.5, zero_point=0 for INT8, so q maps to q * 0.5.
TEST(NNAPIDelegate, DequantizeFourDimensionalInt8Symm) {
  DequantizeOpModel model(TensorType_INT8, {2, 5}, -64, 63.5);

  model.SetInput<int8_t>({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {-64,  -63.5, -63, -62.5, -62,
                                       61.5, 62,    62.5, 63,  63.5};
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}

// Single-op FLOOR model built with the NNAPI delegate.
class FloorOpModel : public SingleOpModelWithNNAPI {
 public:
  // `input_type` is now used for both the input and the output tensor; it was
  // previously ignored in favor of a hard-coded FLOAT32. GetOutput() reads
  // the output back as float, so callers should pass TensorType_FLOAT32
  // (as every current caller does).
  FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
    input_ = AddInput(input_type);
    output_ = AddOutput(input_type);
    SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({
        input_shape,
    });
  }

  int input() { return input_; }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int output_;
};

// FLOOR on a 1-D tensor of two elements.
TEST(NNAPIDelegate, FloorSingleDim) {
  FloorOpModel floor_model({2}, TensorType_FLOAT32);
  floor_model.PopulateTensor<float>(floor_model.input(), {8.5, 0.0});
  ASSERT_EQ(floor_model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {8, 0};
  EXPECT_THAT(floor_model.GetOutput(), NnapiArrayFloatNear(expected));
  EXPECT_THAT(floor_model.GetOutputShape(), ElementsAreArray({2}));
}

// FLOOR rounds toward negative infinity: positive values drop their fraction
// and negative values move down to the next integer.
TEST(NNAPIDelegate, FloorMultiDims) {
  FloorOpModel floor_model({2, 1, 1, 5}, TensorType_FLOAT32);
  floor_model.PopulateTensor<float>(
      floor_model.input(),
      {0.0001, 8.0001, 0.9999, 9.9999, 0.5,           //
       -0.0001, -8.0001, -0.9999, -9.9999, -0.5});
  ASSERT_EQ(floor_model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0,  8,  0,  9,   0,  //
                                       -1, -9, -1, -10, -1};
  EXPECT_THAT(floor_model.GetOutput(), NnapiArrayFloatNear(expected));
  EXPECT_THAT(floor_model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
}

// Single-op LOCAL_RESPONSE_NORMALIZATION model over a float input; `radius`,
// `bias`, `alpha` and `beta` are stored in the op's options.
class LocalResponseNormOpModel : public SingleOpModelWithNNAPI {
 public:
  LocalResponseNormOpModel(std::initializer_list<int> input_shape, int radius,
                           float bias, float alpha, float beta) {
    input_ = AddInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    const auto lrn_options = CreateLocalResponseNormalizationOptions(
        builder_, radius, bias, alpha, beta);
    SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
                 BuiltinOptions_LocalResponseNormalizationOptions,
                 lrn_options.Union());
    BuildInterpreterWithNNAPI({input_shape});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 private:
  int input_;
  int output_;
};

// With a radius covering all six elements, bias 0, alpha 1 and beta 0.5, LRN
// reduces to L2 normalization.
TEST(NNAPIDelegate, LocalResponseNormSameAsL2Norm) {
  LocalResponseNormOpModel lrn({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
                               /*alpha=*/1.0, /*beta=*/0.5);
  lrn.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(lrn.Invoke(), kTfLiteOk);
  // The squares sum to 4, so the divisor is sqrt(4) == 2.
  const std::vector<float> expected = {-0.55, 0.3, 0.35, 0.6, -0.35, 0.05};
  EXPECT_THAT(lrn.GetOutput(), NnapiArrayFloatNear(expected));
}

// Same input as the L2-norm case, with alpha raised to 4.
TEST(NNAPIDelegate, LocalResponseNormWithAlpha) {
  LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
                             /*alpha=*/4.0, /*beta=*/0.5);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  // The squares sum to 4, so the divisor is sqrt(0 + 4 * 4) == 4: the result
  // is every input divided by 4.
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}));
}

// Adding bias 9 with alpha 4 gives a divisor of sqrt(9 + 4 * 4) == 5.
TEST(NNAPIDelegate, LocalResponseNormWithBias) {
  LocalResponseNormOpModel lrn({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0,
                               /*alpha=*/4.0, /*beta=*/0.5);
  lrn.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(lrn.Invoke(), kTfLiteOk);
  // Every input divided by 5.
  const std::vector<float> expected = {-0.22, 0.12, 0.14, 0.24, -0.14, 0.02};
  EXPECT_THAT(lrn.GetOutput(), NnapiArrayFloatNear(expected));
}

// With radius 2 each element only sees its neighbours, so the divisor
// differs per position.
TEST(NNAPIDelegate, LocalResponseNormSmallRadius) {
  LocalResponseNormOpModel lrn({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0,
                               /*alpha=*/4.0, /*beta=*/0.5);
  lrn.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(lrn.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {-0.264926, 0.125109,  0.140112,
                                       0.267261,  -0.161788, 0.0244266};
  EXPECT_THAT(lrn.GetOutput(), NnapiArrayFloatNear(expected));
}

// Single-op LSH_PROJECTION model: a float hash tensor, an int32 input, an
// optional float weight tensor and an int32 output.
class LSHProjectionOpModel : public SingleOpModelWithNNAPI {
 public:
  // The weight input is added (and its shape passed to the interpreter build)
  // only when `weight_shape` is non-empty.
  LSHProjectionOpModel(LSHProjectionType type,
                       std::initializer_list<int> hash_shape,
                       std::initializer_list<int> input_shape,
                       std::initializer_list<int> weight_shape) {
    const bool has_weight = weight_shape.size() > 0;
    hash_ = AddInput(TensorType_FLOAT32);
    input_ = AddInput(TensorType_INT32);
    if (has_weight) {
      weight_ = AddInput(TensorType_FLOAT32);
    }
    output_ = AddOutput(TensorType_INT32);

    SetBuiltinOp(BuiltinOperator_LSH_PROJECTION,
                 BuiltinOptions_LSHProjectionOptions,
                 CreateLSHProjectionOptions(builder_, type).Union());
    if (has_weight) {
      BuildInterpreterWithNNAPI({hash_shape, input_shape, weight_shape});
    } else {
      BuildInterpreterWithNNAPI({hash_shape, input_shape});
    }
    // The previously computed `output_size_` was never read and was removed.
  }
  void SetInput(std::initializer_list<int> data) {
    PopulateTensor(input_, data);
  }

  void SetHash(std::initializer_list<float> data) {
    PopulateTensor(hash_, data);
  }

  // Only meaningful when the model was built with a non-empty weight_shape.
  void SetWeight(std::initializer_list<float> f) { PopulateTensor(weight_, f); }

  std::vector<int> GetOutput() { return ExtractVector<int>(output_); }

 private:
  int input_;
  int hash_;
  // Stays -1 when no weight input is added, instead of being uninitialized.
  int weight_ = -1;
  int output_;
};

// Dense projection with unit weights; expected bits depend on endianness.
TEST(NNAPIDelegate, LSHProjectionDense1DInputs) {
  LSHProjectionOpModel lsh(LSHProjectionType_DENSE, /*hash_shape=*/{3, 2},
                           /*input_shape=*/{5}, /*weight_shape=*/{5});

  lsh.SetInput({12345, 54321, 67890, 9876, -12345678});
  lsh.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  lsh.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0});

  ASSERT_EQ(lsh.Invoke(), kTfLiteOk);

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Hash returns differently on machines with different endianness
  EXPECT_THAT(lsh.GetOutput(), ElementsAre(0, 0, 1, 1, 1, 0));
#else
  EXPECT_THAT(lsh.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
#endif
}

// Sparse projection without weights; each output is written as
// (row offset) + (bucket) in the expectations below.
TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) {
  LSHProjectionOpModel lsh(LSHProjectionType_SPARSE, /*hash_shape=*/{3, 2},
                           /*input_shape=*/{5}, /*weight_shape=*/{});

  lsh.SetInput({12345, 54321, 67890, 9876, -12345678});
  lsh.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});

  ASSERT_EQ(lsh.Invoke(), kTfLiteOk);

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Hash returns differently on machines with different endianness
  EXPECT_THAT(lsh.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
#else
  EXPECT_THAT(lsh.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
#endif
}

// Sparse projection over a [5, 2, 2] input with per-row weights.
TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) {
  LSHProjectionOpModel lsh(LSHProjectionType_SPARSE, /*hash_shape=*/{3, 2},
                           /*input_shape=*/{5, 2, 2}, /*weight_shape=*/{5});

  lsh.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
                9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543});
  lsh.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  lsh.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78});

  ASSERT_EQ(lsh.Invoke(), kTfLiteOk);

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Hash returns differently on machines with different endianness
  EXPECT_THAT(lsh.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
#else
  EXPECT_THAT(lsh.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
#endif
}

// Single-op model for activation builtins that take no options.
class BaseActivationsOpModel : public SingleOpModelWithNNAPI {
 public:
  // Most activations don't take any options, so this constructor works for
  // them. A UINT8 output gets min/max 0 and a fixed scale of 1/256; any other
  // type reuses the input type with an empty shape.
  BaseActivationsOpModel(BuiltinOperator type, const TensorData& input) {
    input_ = AddInput(input);
    output_ = (input.type == TensorType_UINT8)
                  ? AddOutput({input.type, {}, 0, 0, 1. / 256})
                  : AddOutput({input.type, {}});
    SetBuiltinOp(type, BuiltinOptions_NONE, /*builtin_options=*/0);
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  // Variant with a caller-specified output tensor description.
  BaseActivationsOpModel(BuiltinOperator type, const TensorData& input,
                         const TensorData& output) {
    input_ = AddInput(input);
    output_ = AddOutput(output);
    SetBuiltinOp(type, BuiltinOptions_NONE, /*builtin_options=*/0);
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

 protected:
  int input_;
  int output_;
};

// Float accessors for activation models.
class FloatActivationsOpModel : public BaseActivationsOpModel {
 public:
  using BaseActivationsOpModel::BaseActivationsOpModel;

  // Copies `data` into the input tensor.
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }
  // Reads the whole output tensor.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};

// Absolute tolerance for dequantized activation outputs: two steps of 1/256.
constexpr float kQuantizedTolerance = 2.0f / 256.0f;

class QuantizedActivationsOpModel : public BaseActivationsOpModel {
 public:
  using BaseActivationsOpModel::BaseActivationsOpModel;

  template <typename T>
  void SetInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<T>(input_, data);
  }
  template <typename T>

  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }
  template <typename T>
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
                         GetZeroPoint(output_));
  }
};

// RELU clamps negatives to zero and passes positives through.
TEST(NNAPIDelegate, Relu) {
  FloatActivationsOpModel relu(BuiltinOperator_RELU,
                               /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  relu.SetInput({0, -6, 2, 4,  //
                 3, -2, 10, 1});
  ASSERT_EQ(relu.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0, 0, 2, 4,  //
                                       3, 0, 10, 1};
  EXPECT_THAT(relu.GetOutput(), NnapiArrayFloatNear(expected));
}

// RELU_N1_TO_1 clamps every value into [-1, 1].
TEST(NNAPIDelegate, Relu1) {
  FloatActivationsOpModel relu1(BuiltinOperator_RELU_N1_TO_1,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  relu1.SetInput({0.0, -0.6, 0.2, -0.4,  //
                  0.3, -2.0, 1.1, -0.1});
  ASSERT_EQ(relu1.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0.0, -0.6, 0.2, -0.4,  //
                                       0.3, -1.0, 1.0, -0.1};
  EXPECT_THAT(relu1.GetOutput(), NnapiArrayFloatNear(expected));
}

// RELU6 clamps every value into [0, 6].
TEST(NNAPIDelegate, Relu6) {
  FloatActivationsOpModel relu6(BuiltinOperator_RELU6,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  relu6.SetInput({0, -6, 2, 4,  //
                  3, -2, 10, 1});
  ASSERT_EQ(relu6.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0, 0, 2, 4,  //
                                       3, 0, 6, 1};
  EXPECT_THAT(relu6.GetOutput(), NnapiArrayFloatNear(expected));
}

// LOGISTIC computes sigmoid(x) = 1 / (1 + exp(-x)) elementwise.
TEST(NNAPIDelegate, LogisticFloat) {
  FloatActivationsOpModel logistic(
      BuiltinOperator_LOGISTIC,
      /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  logistic.SetInput({0, -6, 2, 4,  //
                     3, -2, 10, 1});
  ASSERT_EQ(logistic.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {
      0.5,      0.002473, 0.880797, 0.982014,  //
      0.952574, 0.119203, 0.999955, 0.731059};
  EXPECT_THAT(logistic.GetOutput(), NnapiArrayFloatNear(expected));
}

// Quantized LOGISTIC over an input range of [-10, 10].
TEST(NNAPIDelegate, LogisticQuantized) {
  QuantizedActivationsOpModel logistic(
      BuiltinOperator_LOGISTIC,
      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
  logistic.SetInput<uint8_t>({0, -6, 2, 4,  //
                              3, -2, 10, 1});
  ASSERT_EQ(logistic.Invoke(), kTfLiteOk);
  // Dequantized results must match sigmoid(x) within kQuantizedTolerance.
  const std::vector<float> expected_sigmoid = {
      0.5,      0.002473, 0.880797, 0.982014,  //
      0.952574, 0.119203, 0.999955, 0.731059};
  EXPECT_THAT(
      logistic.GetDequantizedOutput<uint8_t>(),
      ElementsAreArray(ArrayFloatNear(expected_sigmoid, kQuantizedTolerance)));
  // Raw bytes may differ from the reference by at most one.
  EXPECT_THAT(logistic.GetOutput<uint8_t>(),
              testing::Pointwise(QuantizedNear(),
                                 {128, 1, 227, 251, 244, 32, 255, 188}));
}

// Single-op RESIZE_BILINEAR model. A non-empty `size_data` becomes a constant
// int32 size tensor of shape {2}; an empty one makes the size a regular input
// that must be filled through SetSize().
class ResizeBilinearOpModel : public SingleOpModelWithNNAPI {
 public:
  ResizeBilinearOpModel(const TensorData& input,
                        std::initializer_list<int> size_data) {
    const bool const_size = size_data.size() != 0;
    input_ = AddInput(input);
    size_ = const_size ? AddConstInput(TensorType_INT32, size_data, {2})
                       : AddInput({TensorType_INT32, {2}});
    output_ = AddOutput(input.type);
    SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
                 BuiltinOptions_ResizeBilinearOptions,
                 CreateResizeBilinearOptions(builder_).Union());
    // The size tensor's shape is only passed to the build when it is not a
    // constant.
    if (!const_size) {
      BuildInterpreterWithNNAPI({GetShape(input_), GetShape(size_)});
      return;
    }
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  template <typename T>
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor(input_, data);
  }
  // Only valid when the model was built with an empty `size_data`.
  void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

 private:
  int input_;
  int size_;
  int output_;
};

// Widens a 1x2 image to 1x3 with the size supplied at runtime.
TEST(ResizeBilinear, Horizontal) {
  ResizeBilinearOpModel resize({TensorType_FLOAT32, {1, 1, 2, 1}}, {});
  resize.SetInput<float>({3, 6});
  resize.SetSize({1, 3});
  ASSERT_EQ(resize.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {3, 5, 6};
  EXPECT_THAT(resize.GetOutput<float>(), NnapiArrayFloatNear(expected));
}

// Same as Horizontal, but with the size baked in as a constant tensor.
TEST(ResizeBilinear, HorizontalConstant) {
  ResizeBilinearOpModel resize({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
  resize.SetInput<float>({3, 6});
  ASSERT_EQ(resize.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {3, 5, 6};
  EXPECT_THAT(resize.GetOutput<float>(), NnapiArrayFloatNear(expected));
}

// Heightens a 2x1 image to 3x1 with the size supplied at runtime.
TEST(ResizeBilinear, Vertical) {
  ResizeBilinearOpModel resize({TensorType_FLOAT32, {1, 2, 1, 1}}, {});
  resize.SetInput<float>({3, 9});
  resize.SetSize({3, 1});
  ASSERT_EQ(resize.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {3, 7, 9};
  EXPECT_THAT(resize.GetOutput<float>(), NnapiArrayFloatNear(expected));
}

// Same as Vertical, but with the size baked in as a constant tensor.
TEST(ResizeBilinear, VerticalConstant) {
  ResizeBilinearOpModel resize({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
  resize.SetInput<float>({3, 9});
  ASSERT_EQ(resize.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {3, 7, 9};
  EXPECT_THAT(resize.GetOutput<float>(), NnapiArrayFloatNear(expected));
}

// Resizes a 2x2 image to 3x3 with the size supplied at runtime.
TEST(ResizeBilinear, TwoDimensional) {
  ResizeBilinearOpModel resize({TensorType_FLOAT32, {1, 2, 2, 1}}, {});
  resize.SetInput<float>({3, 6,  //
                          9, 12});
  resize.SetSize({3, 3});
  ASSERT_EQ(resize.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {3, 5,  6,   //
                                       7, 9,  10,  //
                                       9, 11, 12};
  EXPECT_THAT(resize.GetOutput<float>(), NnapiArrayFloatNear(expected));
}

// Same as TwoDimensional, but with the size baked in as a constant tensor.
TEST(ResizeBilinear, TwoDimensionalConstant) {
  ResizeBilinearOpModel resize({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
  resize.SetInput<float>({3, 6,  //
                          9, 12});
  ASSERT_EQ(resize.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {3, 5,  6,   //
                                       7, 9,  10,  //
                                       9, 11, 12};
  EXPECT_THAT(resize.GetOutput<float>(), NnapiArrayFloatNear(expected));
}

// Accessors shared by PAD test models whose tensors hold elements of type T.
// Derived classes assign the tensor ids below while building the graph.
template <typename T>
class PadOpModel : public SingleOpModelWithNNAPI {
 public:
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor<T>(input_, data);
  }

  // Quantizes `data` with the input tensor's parameters.
  template <typename QuantizedInputOutput>
  void SetQuantizedInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<QuantizedInputOutput>(input_, data);
  }

  // Quantizes a single pad value into the `constant_values_` tensor.
  template <typename QuantizedInputOutput>
  void SetQuantizedPadValue(float data) {
    QuantizeAndPopulate<QuantizedInputOutput>(constant_values_, {data});
  }

  void SetPaddings(std::initializer_list<int> paddings) {
    PopulateTensor<int>(paddings_, paddings);
  }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  // Output converted back to float using the output's scale and zero point.
  template <typename QuantizedInputOutput>
  std::vector<float> GetDequantizedOutput() {
    const auto raw_output = ExtractVector<QuantizedInputOutput>(output_);
    return Dequantize<QuantizedInputOutput>(raw_output, GetScale(output_),
                                            GetZeroPoint(output_));
  }

 protected:
  int input_;
  int output_;
  int paddings_;
  int constant_values_;
};

// Float PAD model whose paddings are a constant int32 tensor of shape
// `paddings_shape`.
class PadOpConstModel : public PadOpModel<float> {
 public:
  PadOpConstModel(const TensorData& input,
                  std::initializer_list<int> paddings_shape,
                  std::initializer_list<int> paddings,
                  const TensorData& output) {
    input_ = AddInput(input);
    paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
    output_ = AddOutput(output);

    const auto pad_options = CreatePadOptions(builder_);
    SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
                 pad_options.Union());
    BuildInterpreterWithNNAPI({input.shape});
  }
};

// Paddings {0,0, 0,2, 1,3, 0,0} grow a [1, 2, 3, 1] input to [1, 4, 7, 1]:
// two rows after, one column before and three columns after, all zero.
TEST(NNAPIDelegate, PadAdvancedConstTest) {
  PadOpConstModel pad({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                      {0, 0, 0, 2, 1, 3, 0, 0}, {TensorType_FLOAT32});
  pad.SetInput({1, 2, 3, 4, 5, 6});
  ASSERT_EQ(pad.Invoke(), kTfLiteOk);
  const std::vector<float> expected = {0, 1, 2, 3, 0, 0, 0,  //
                                       0, 4, 5, 6, 0, 0, 0,  //
                                       0, 0, 0, 0, 0, 0, 0,  //
                                       0, 0, 0, 0, 0, 0, 0};
  EXPECT_THAT(pad.GetOutput(), NnapiArrayFloatNear(expected));
  EXPECT_THAT(pad.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}

// Base harness for SPACE_TO_BATCH_ND tests. It holds the tensor indices and
// typed accessors; subclasses register the tensors and build the interpreter.
class SpaceToBatchNDOpModel : public SingleOpModelWithNNAPI {
 public:
  // --- Accessors for results ---
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  // --- Setters for the op's inputs ---
  void SetPaddings(std::initializer_list<int> data) {
    PopulateTensor<int>(paddings_, data);
  }

  void SetBlockShape(std::initializer_list<int> data) {
    PopulateTensor<int>(block_shape_, data);
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }

 protected:
  int input_;
  int block_shape_;
  int paddings_;
  int output_;
};

// SPACE_TO_BATCH_ND model whose block shape and paddings are constant int32
// tensors. Only the float input is fed at run time.
//
// The number of spatial dimensions M is taken from `block_shape.size()`:
// `block_shape` is registered with shape {M} and `paddings` (M before/after
// pairs) with shape {M, 2}. For the usual two-element block shape this gives
// the same {2} / {2, 2} shapes as before, and other spatial ranks also work.
class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
 public:
  SpaceToBatchNDOpConstModel(std::initializer_list<int> input_shape,
                             std::initializer_list<int> block_shape,
                             std::initializer_list<int> paddings) {
    const int num_spatial_dims = static_cast<int>(block_shape.size());

    input_ = AddInput(TensorType_FLOAT32);
    block_shape_ =
        AddConstInput(TensorType_INT32, block_shape, {num_spatial_dims});
    paddings_ =
        AddConstInput(TensorType_INT32, paddings, {num_spatial_dims, 2});
    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
                 BuiltinOptions_SpaceToBatchNDOptions,
                 CreateSpaceToBatchNDOptions(builder_).Union());
    BuildInterpreterWithNNAPI({input_shape});
  }
};

TEST(NNAPIDelegate, SpaceToBatchNDSimpleConstTest) {
  // 1x4x4x1 input, 2x2 blocks, no padding -> 4x2x2x1.
  SpaceToBatchNDOpConstModel model({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1, 3, 9,  11, 2, 4, 10, 12,
                                       5, 7, 13, 15, 6, 8, 14, 16};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(NNAPIDelegate, SpaceToBatchNDMultipleInputBatchesConstTest) {
  // 2x2x4x1 input, 2x2 blocks, no padding -> 8x1x2x1.
  SpaceToBatchNDOpConstModel model({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1, 3, 9,  11, 2, 4, 10, 12,
                                       5, 7, 13, 15, 6, 8, 14, 16};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(NNAPIDelegate, SpaceToBatchNDSimplePaddingConstTest) {
  // 1x5x2x1 input, 3x2 blocks, paddings {{1,0},{2,0}} -> 6x2x2x1.
  SpaceToBatchNDOpConstModel model({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7,  //
      0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
  };
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(NNAPIDelegate, SpaceToBatchNDComplexPaddingConstTest) {
  // 1x4x2x1 input, 3x2 blocks, paddings {{1,1},{2,4}} -> 6x2x4x1.
  SpaceToBatchNDOpConstModel model({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0,  //
      0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0,  //
      0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
  };
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

// Single STRIDED_SLICE op model. `begin`, `end` and `strides` are constant
// int32 tensors. Only the input (C++ element type `input_type`, TFLite type
// `tensor_input_type`) is fed at run time. The five mask arguments are
// forwarded unchanged to StridedSliceOptions.
template <typename input_type = float,
          TensorType tensor_input_type = TensorType_FLOAT32>
class StridedSliceOpModel : public SingleOpModelWithNNAPI {
 public:
  StridedSliceOpModel(std::initializer_list<int> input_shape,
                      std::initializer_list<int> begin_shape,
                      std::initializer_list<int> begin_data,
                      std::initializer_list<int> end_shape,
                      std::initializer_list<int> end_data,
                      std::initializer_list<int> strides_shape,
                      std::initializer_list<int> strides_data, int begin_mask,
                      int end_mask, int ellipsis_mask, int new_axis_mask,
                      int shrink_axis_mask) {
    input_ = AddInput(tensor_input_type);
    begin_ = AddConstInput(TensorType_INT32, begin_data, begin_shape);
    end_ = AddConstInput(TensorType_INT32, end_data, end_shape);
    strides_ = AddConstInput(TensorType_INT32, strides_data, strides_shape);
    output_ = AddOutput(tensor_input_type);

    const auto slice_options =
        CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
                                  new_axis_mask, shrink_axis_mask)
            .Union();
    SetBuiltinOp(BuiltinOperator_STRIDED_SLICE,
                 BuiltinOptions_StridedSliceOptions, slice_options);
    BuildInterpreterWithNNAPI(
        {input_shape, begin_shape, end_shape, strides_shape});
  }

  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  std::vector<input_type> GetOutput() {
    return ExtractVector<input_type>(output_);
  }

  void SetInput(std::initializer_list<input_type> data) {
    PopulateTensor<input_type>(input_, data);
  }

 private:
  int input_;
  int begin_;
  int end_;
  int strides_;
  int output_;
};

TEST(StridedSliceOpTest, In1D) {
  // Elements [1, 3) of a 4-element vector, stride 1, no masks.
  StridedSliceOpModel<> model({4}, {1}, {1}, {1}, {3}, {1}, {1}, 0, 0, 0, 0,
                              0);
  model.SetInput({1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {2, 3};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(StridedSliceOpTest, In1D_BeginMask) {
  // begin_mask = 1 ignores begin on axis 0, so the slice is [0, 3).
  StridedSliceOpModel<> model({4}, {1}, {1}, {1}, {3}, {1}, {1}, 1, 0, 0, 0,
                              0);
  model.SetInput({1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1, 2, 3};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(StridedSliceOpTest, In2D_Stride2) {
  // Whole 2x3 range with stride 2 on both axes.
  StridedSliceOpModel<> model({2, 3}, {2}, {0, 0}, {2}, {2, 3}, {2}, {2, 2},
                              0, 0, 0, 0, 0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1, 3};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(StridedSliceOpTest, In2D_EndMask) {
  // end_mask = 2 ignores end on axis 1, so the full second row is taken.
  StridedSliceOpModel<> model({2, 3}, {2}, {1, 0}, {2}, {2, 2}, {2}, {1, 1},
                              0, 2, 0, 0, 0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {4, 5, 6};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
  // shrink_axis_mask = 4 drops axis 2 after taking index 0 on it.
  StridedSliceOpModel<> model({2, 3, 2}, {3}, {0, 0, 0}, {3}, {2, 3, 1}, {3},
                              {1, 1, 1}, 0, 0, 0, 0, 4);
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {1, 3, 5, 7, 9, 11};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
}

// Input sequence for RnnBlackBoxTest: 128 floats, i.e. 16 slices of
// input_size = 8. That test feeds the same slice to both batches each step, so
// it only consumes the first 8 slices (64 floats).
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};

// Expected RNN outputs: 16 groups of num_units = 16 values, one group per time
// step. RnnBlackBoxTest compares group i against both batches' outputs at
// step i.
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};

// Input-to-hidden weights for RnnBlackBoxTest: 128 values, matching the
// {units, input_size} = {16, 8} weights shape used by RNNOpModel(2, 16, 8).
// Initializing a namespace-scope std::initializer_list from a braced list
// extends the backing array's lifetime to that of the variable, so the list
// stays valid for the whole program.
static std::initializer_list<float> rnn_weights = {
    0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
    0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
    0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
    -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
    -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
    -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
    -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
    0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
    0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
    0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
    -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
    0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
    -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
    -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
    0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
    0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
    0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
    -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
    0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
    0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
    -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
    0.277308,    0.415818};

// Recurrent (hidden-to-hidden) weights: a 16x16 matrix (256 values) equal to
// 0.1 on the diagonal and 0 elsewhere. Each source line below is "0.1" followed
// by 16 zeros, which places the next 0.1 one row down and one column right.
static std::initializer_list<float> rnn_recurrent_weights = {
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1};

// RNN bias: one value per unit (16 values).
static std::initializer_list<float> rnn_bias = {
    0.065691948, -0.69055247, 0.1107955,  -0.97084129, -0.23957068, -0.23566568,
    -0.389184,   0.47481549,  -0.4791103, 0.29931796,  0.10463274,  0.83918178,
    0.37197268,  0.61957061,  0.3956964,  -0.37609905};

// Single-op RNN model with RELU activation, built through the NNAPI delegate.
// The hidden state is a float variable input of shape {batches, units}. The
// weight tensor types are configurable; the input, bias and output are float.
class RNNOpModel : public SingleOpModelWithNNAPI {
 public:
  RNNOpModel(int batches, int units, int size,
             const TensorType weights = TensorType_FLOAT32,
             const TensorType recurrent_weights = TensorType_FLOAT32)
      : batches_(batches), units_(units), input_size_(size) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_ = AddInput(weights);
    recurrent_weights_ = AddInput(recurrent_weights);
    bias_ = AddInput(TensorType_FLOAT32);
    hidden_state_ = AddVariableInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);

    const auto rnn_options =
        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union();
    SetBuiltinOp(BuiltinOperator_RNN, BuiltinOptions_RNNOptions, rnn_options);

    const std::vector<std::vector<int>> shapes = {
        {batches_, input_size_},  // input tensor
        {units_, input_size_},    // weights tensor
        {units_, units_},         // recurrent weights tensor
        {units_},                 // bias tensor
        {batches_, units_},       // hidden state tensor
    };
    BuildInterpreterWithNNAPI(shapes);
  }

  // --- Parameter setters ---
  void SetWeights(std::initializer_list<float> f) {
    PopulateTensor(weights_, f);
  }

  void SetRecurrentWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_weights_, f);
  }

  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }

  // --- Input setters ---
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  // Copies [begin, end) into the input tensor starting at element `offset`.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  // --- Results and dimensions ---
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  int input_;
  int weights_;
  int recurrent_weights_;
  int bias_;
  int hidden_state_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
};

TEST(NNAPIDelegate, RnnBlackBoxTest) {
  RNNOpModel rnn(2, 16, 8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int step_width = rnn.input_size();
  const int units = rnn.num_units();
  // The step count divides the input length by (input_size * num_batches).
  // Each step feeds the same slice to both batches, so only the first half of
  // rnn_input is consumed.
  const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
                                  (step_width * rnn.num_batches());

  for (int step = 0; step < input_sequence_size; ++step) {
    float* const slice_begin = rnn_input + step * step_width;
    float* const slice_end = slice_begin + step_width;
    // Batch 0 and batch 1 receive identical inputs.
    rnn.SetInput(0, slice_begin, slice_end);
    rnn.SetInput(step_width, slice_begin, slice_end);

    ASSERT_EQ(rnn.Invoke(), kTfLiteOk);

    // Both batches are therefore expected to produce the same golden row.
    const float* const golden_begin = rnn_golden_output + step * units;
    const float* const golden_end = golden_begin + units;
    std::vector<float> expected(golden_begin, golden_end);
    expected.insert(expected.end(), golden_begin, golden_end);

    EXPECT_THAT(rnn.GetOutput(), NnapiArrayFloatNear(expected));
  }
}

// SVDF input sequence: 10 groups (one per time step) of 2 rows x 3 values,
// i.e. batches = 2 and input_size = 3 as used by the SVDF black-box tests
// (60 floats total).
static float svdf_input[] = {
    0.12609188,  -0.46347019, -0.89598465,
    0.35867718,  0.36897406,  0.73463392,

    0.14278367,  -1.64410412, -0.75222826,
    -0.57290924, 0.12729003,  0.7567004,

    0.49837467,  0.19278903,  0.26584083,
    0.17660543,  0.52949083,  -0.77931279,

    -0.11186574, 0.13164264,  -0.05349274,
    -0.72674477, -0.5683046,  0.55900657,

    -0.68892461, 0.37783599,  0.18263303,
    -0.63690937, 0.44483393,  -0.71817774,

    -0.81299269, -0.86831826, 1.43940818,
    -0.95760226, 1.82078898,  0.71135032,

    -1.45006323, -0.82251364, -1.69082689,
    -1.65087092, -1.89238167, 1.54172635,

    0.03966608,  -0.24936394, -0.77526885,
    2.06740379,  -1.51439476, 1.43768692,

    0.11771342,  -0.23761693, -0.65898693,
    0.31088525,  -1.55601168, -0.87661445,

    -0.89477462, 1.67204106,  -0.53235275,
    -0.6230064,  0.29819036,  1.06939757,
};

// Expected SVDF outputs for the rank-1 test: 10 groups (one per time step) of
// 2 batches x 4 units (80 floats total).
static float svdf_golden_output_rank_1[] = {
    0.014899,    -0.0517661,  -0.143725,   -0.00271883,
    -0.03004015, 0.09565311,  0.1587342,   0.00784263,

    0.068281,    -0.162217,   -0.152268,   0.00323521,
    0.01582633,  0.03858774,  -0.03001583, -0.02671271,

    -0.0317821,  -0.0333089,  0.0609602,   0.0333759,
    -0.01432795, 0.05524484,  0.1101355,   -0.02382665,

    -0.00623099, -0.077701,   -0.391193,   -0.0136691,
    -0.02333033, 0.02293761,  0.12338032,  0.04326871,

    0.201551,    -0.164607,   -0.179462,   -0.0592739,
    0.01064911,  -0.17503069, 0.07821996,  -0.00224009,

    0.0886511,   -0.0875401,  -0.269283,   0.0281379,
    -0.02282338, 0.09741908,  0.32973239,  0.12281385,

    -0.201174,   -0.586145,   -0.628624,   -0.0330412,
    0.24780814,  -0.39304617, -0.22473189, 0.02589256,

    -0.0839096,  -0.299329,   0.108746,    0.109808,
    0.10084175,  -0.06416984, 0.28936723,  0.0026358,

    0.419114,    -0.237824,   -0.422627,   0.175115,
    -0.2314795,  -0.18584411, -0.4228974,  -0.12928449,

    0.36726,     -0.522303,   -0.456502,   -0.175475,
    0.17012937,  -0.34447709, 0.38505614,  -0.28158101,
};

// Expected SVDF outputs for the rank-2 test: 10 groups (one per time step) of
// 2 batches x 4 units (80 floats total).
static float svdf_golden_output_rank_2[] = {
    -0.09623547, -0.10193135, 0.11083051,  -0.0347917,
    0.1141196,   0.12965347,  -0.12652366, 0.01007236,

    -0.16396809, -0.21247184, 0.11259045,  -0.04156673,
    0.10132131,  -0.06143532, -0.00924693, 0.10084561,

    0.01257364,  0.0506071,   -0.19287863, -0.07162561,
    -0.02033747, 0.22673416,  0.15487903,  0.02525555,

    -0.1411963,  -0.37054959, 0.01774767,  0.05867489,
    0.09607603,  -0.0141301,  -0.08995658, 0.12867066,

    -0.27142537, -0.16955489, 0.18521598,  -0.12528358,
    0.00331409,  0.11167502,  0.02218599,  -0.07309391,

    0.09593632,  -0.28361851, -0.0773851,  0.17199151,
    -0.00075242, 0.33691186,  -0.1536046,  0.16572715,

    -0.27916506, -0.27626723, 0.42615682,  0.3225764,
    -0.37472126, -0.55655634, -0.05013514, 0.289112,

    -0.24418658, 0.07540751,  -0.1940318,  -0.08911639,
    0.00732617,  0.46737891,  0.26449674,  0.24888524,

    -0.17225097, -0.54660404, -0.38795233, 0.08389944,
    0.07736043,  -0.28260678, 0.15666828,  1.14949894,

    -0.57454878, -0.64704704, 0.73235172,  -0.34616736,
    0.21120001,  -0.22927976, 0.02455296,  -0.35906726,
};

// Single-op SVDF model (no activation) built through the NNAPI delegate.
// Number of filters = units * rank. The activation state is a float variable
// input of shape {batches, memory_size * num_filters}.
class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
 public:
  BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
                  int rank,
                  TensorType weights_feature_type = TensorType_FLOAT32,
                  TensorType weights_time_type = TensorType_FLOAT32)
      : batches_(batches),
        units_(units),
        input_size_(input_size),
        memory_size_(memory_size),
        rank_(rank) {
    const int num_filters = units * rank;
    const int state_width = memory_size * num_filters;

    input_ = AddInput(TensorType_FLOAT32);
    weights_feature_ = AddInput(weights_feature_type);
    weights_time_ = AddInput(weights_time_type);
    // TODO(b/121383394) : figure out why optional bias causes TFLite segfault
    // when using NNAPI delegate.
    bias_ = AddInput(TensorType_FLOAT32);
    activation_state_ = AddVariableInput(
        TensorData{TensorType_FLOAT32, {batches, state_width}});
    output_ = AddOutput(TensorType_FLOAT32);

    const auto svdf_options =
        CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union();
    SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
                 svdf_options);

    const std::vector<std::vector<int>> shapes = {
        {batches_, input_size_},      // input tensor
        {num_filters, input_size_},   // weights_feature tensor
        {num_filters, memory_size_},  // weights_time tensor
        {units_},                     // bias tensor
        {batches, state_width},       // activation_state tensor
    };
    BuildInterpreterWithNNAPI(shapes);

    // TODO(b/121383394) : remove once the optional bias bug is fixed.
    // Until then the bias is a mandatory input filled with `units_` zeros.
    PopulateTensor(bias_, std::vector<float>(units_));
  }

  // Fills the weights_feature tensor.
  void SetWeightsFeature(std::initializer_list<float> f) {
    PopulateTensor(weights_feature_, f);
  }

  // Fills the weights_time tensor.
  void SetWeightsTime(std::initializer_list<float> f) {
    PopulateTensor(weights_time_, f);
  }

  // Copies [begin, end) into the input tensor starting at element `offset`.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  // Returns the SVDF op's output tensor contents.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  int input_;
  int weights_feature_;
  int weights_time_;
  int bias_;
  int activation_state_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
  int memory_size_;
  int rank_;
};

// Concrete SVDF model. It inherits every BaseSVDFOpModel constructor (and its
// default arguments) and adds no behavior of its own.
class SVDFOpModel : public BaseSVDFOpModel {
 public:
  using BaseSVDFOpModel::BaseSVDFOpModel;
};

class SVDFOpTest : public ::testing::Test {
 protected:
  // Feeds `golden_input` to `svdf` one time step at a time and compares each
  // step's output with the matching slice of `golden_output`, within
  // `tolerance`.
  //
  // `golden_size` is the size of `golden_input` in bytes (callers pass
  // sizeof(...)). Each step consumes input_size * num_batches floats and is
  // checked against num_units * num_batches expected floats.
  void VerifyGoldens(float golden_input[], float golden_output[],
                     int golden_size, BaseSVDFOpModel* svdf,
                     float tolerance = 1e-5) {
    const int batches = svdf->num_batches();
    const int input_stride = svdf->input_size() * batches;
    const int output_stride = svdf->num_units() * batches;
    const int input_sequence_size = golden_size / sizeof(float) / input_stride;

    for (int step = 0; step < input_sequence_size; ++step) {
      float* const input_begin = golden_input + step * input_stride;
      svdf->SetInput(0, input_begin, input_begin + input_stride);

      ASSERT_EQ(svdf->Invoke(), kTfLiteOk);

      const float* const expected_begin = golden_output + step * output_stride;
      const std::vector<float> expected(expected_begin,
                                        expected_begin + output_stride);
      EXPECT_THAT(svdf->GetOutput(),
                  ElementsAreArray(ArrayFloatNear(expected, tolerance)));
    }
  }
};

TEST_F(SVDFOpTest, BlackBoxTestRank1) {
  // 2 batches, 4 units, 3 inputs, memory of 10, rank 1:
  // weights_feature is {4, 3}, weights_time is {4, 10}.
  SVDFOpModel model(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                    /*memory_size=*/10, /*rank=*/1);
  model.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
                           0.3905206, -0.36137494, -0.06634006, -0.10640851});

  model.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});

  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &model);
}

TEST_F(SVDFOpTest, BlackBoxTestRank2) {
  // 2 batches, 4 units, 3 inputs, memory of 10, rank 2:
  // weights_feature is {8, 3}, weights_time is {8, 10}.
  SVDFOpModel model(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                    /*memory_size=*/10, /*rank=*/2);
  model.SetWeightsFeature({-0.31930989, 0.0079667,   0.39296314,  0.37613347,
                           0.12416199,  0.15785322,  0.27901134,  0.3905206,
                           0.21931258,  -0.36137494, -0.10640851, 0.31053296,
                           -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
                           0.15294972,  0.38031587,  0.27557442,  0.39635518,
                           -0.21580373, -0.06634006, -0.02702999, 0.27072677});

  model.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

       -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
       0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

       -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
       0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

       -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
       -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

       0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
       0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763});

  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &model);
}

// Single-op LSTM model (TANH activation) built through the NNAPI delegate.
//
// Inputs are registered in a fixed order:
//   input; input-to-{input,forget,cell,output} weights;
//   recurrent-to-{input,forget,cell,output} weights;
//   cell-to-{input,forget,output} (peephole) weights;
//   {input,forget,cell,output} gate biases; projection weights and bias;
//   the activation-state and cell-state variable inputs;
//   and, when more than 20 input shapes are given, the four layer-norm
//   coefficient tensors.
// Inputs disabled by `use_cifg`, `use_peephole` or the projection flags are
// registered as null (optional) inputs, so every later input keeps its
// position.
//
// `input_shapes` holds one shape per registered input, in that order, and is
// passed to BuildInterpreterWithNNAPI. Weight tensors use `weight_type` and
// are filled through SetData. Biases, layer-norm coefficients and the input
// are float.
class LSTMOpModel : public SingleOpModelWithNNAPI {
 public:
  LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, float cell_clip, float proj_clip,
              const std::vector<std::vector<int>>& input_shapes,
              const TensorType weight_type)
      : n_batch_(n_batch),
        n_input_(n_input),
        n_cell_(n_cell),
        n_output_(n_output),
        weight_type_(weight_type) {
    input_ = AddInput(TensorType_FLOAT32);

    // CIFG couples the input gate to the forget gate, so no input-gate
    // weights are registered in that mode.
    if (use_cifg) {
      input_to_input_weights_ = AddNullInput();
    } else {
      input_to_input_weights_ = AddInput(weight_type);
    }

    input_to_forget_weights_ = AddInput(weight_type);
    input_to_cell_weights_ = AddInput(weight_type);
    input_to_output_weights_ = AddInput(weight_type);

    if (use_cifg) {
      recurrent_to_input_weights_ = AddNullInput();
    } else {
      recurrent_to_input_weights_ = AddInput(weight_type);
    }

    recurrent_to_forget_weights_ = AddInput(weight_type);
    recurrent_to_cell_weights_ = AddInput(weight_type);
    recurrent_to_output_weights_ = AddInput(weight_type);

    if (use_peephole) {
      if (use_cifg) {
        cell_to_input_weights_ = AddNullInput();
      } else {
        cell_to_input_weights_ = AddInput(weight_type);
      }
      cell_to_forget_weights_ = AddInput(weight_type);
      cell_to_output_weights_ = AddInput(weight_type);
    } else {
      cell_to_input_weights_ = AddNullInput();
      cell_to_forget_weights_ = AddNullInput();
      cell_to_output_weights_ = AddNullInput();
    }

    if (use_cifg) {
      input_gate_bias_ = AddNullInput();
    } else {
      input_gate_bias_ = AddInput(TensorType_FLOAT32);
    }
    forget_gate_bias_ = AddInput(TensorType_FLOAT32);
    cell_bias_ = AddInput(TensorType_FLOAT32);
    output_gate_bias_ = AddInput(TensorType_FLOAT32);

    if (use_projection_weights) {
      projection_weights_ = AddInput(weight_type);
      if (use_projection_bias) {
        projection_bias_ = AddInput(TensorType_FLOAT32);
      } else {
        projection_bias_ = AddNullInput();
      }
    } else {
      projection_weights_ = AddNullInput();
      projection_bias_ = AddNullInput();
    }

    // Adding the 2 input state tensors.
    input_activation_state_ = AddVariableInput(TensorType_FLOAT32);
    input_cell_state_ = AddVariableInput(TensorType_FLOAT32);

    // The first 20 shapes cover the inputs registered above; any extra shapes
    // turn on the layer-norm coefficient inputs.
    const bool use_layer_norm = input_shapes.size() > 20;
    // Layer norm weights.
    if (use_layer_norm) {
      const int kInputLayerNormCoeffsIndex = 20;
      const int kForgetLayerNormCoeffsIndex = 21;
      const int kCellLayerNormCoeffsIndex = 22;
      const int kOutputLayerNormCoeffsIndex = 23;

      if (use_cifg) {
        input_layer_norm_coefficients_ = AddNullInput();
      } else {
        input_layer_norm_coefficients_ =
            AddLayerNormCoeffsTensor(kInputLayerNormCoeffsIndex, input_shapes);
      }
      forget_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kForgetLayerNormCoeffsIndex, input_shapes);
      cell_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kCellLayerNormCoeffsIndex, input_shapes);
      output_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kOutputLayerNormCoeffsIndex, input_shapes);
    }

    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
                                   cell_clip, proj_clip)
                     .Union());
    BuildInterpreterWithNNAPI(input_shapes);
  }

  // Weight setters: each writes `f` via SetData using the model's
  // weight_type_ (presumably quantizing for non-float weight types; see
  // SetData).
  void SetInputToInputWeights(const std::vector<float>& f) {
    SetData(input_to_input_weights_, weight_type_, f);
  }

  void SetInputToForgetWeights(const std::vector<float>& f) {
    SetData(input_to_forget_weights_, weight_type_, f);
  }

  void SetInputToCellWeights(const std::vector<float>& f) {
    SetData(input_to_cell_weights_, weight_type_, f);
  }

  void SetInputToOutputWeights(const std::vector<float>& f) {
    SetData(input_to_output_weights_, weight_type_, f);
  }

  void SetRecurrentToInputWeights(const std::vector<float>& f) {
    SetData(recurrent_to_input_weights_, weight_type_, f);
  }

  void SetRecurrentToForgetWeights(const std::vector<float>& f) {
    SetData(recurrent_to_forget_weights_, weight_type_, f);
  }

  void SetRecurrentToCellWeights(const std::vector<float>& f) {
    SetData(recurrent_to_cell_weights_, weight_type_, f);
  }

  void SetRecurrentToOutputWeights(const std::vector<float>& f) {
    SetData(recurrent_to_output_weights_, weight_type_, f);
  }

  void SetCellToInputWeights(const std::vector<float>& f) {
    SetData(cell_to_input_weights_, weight_type_, f);
  }

  void SetCellToForgetWeights(const std::vector<float>& f) {
    SetData(cell_to_forget_weights_, weight_type_, f);
  }

  void SetCellToOutputWeights(const std::vector<float>& f) {
    SetData(cell_to_output_weights_, weight_type_, f);
  }

  // Float bias setters.
  void SetInputGateBias(const std::vector<float>& f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(const std::vector<float>& f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(const std::vector<float>& f) {
    PopulateTensor(cell_bias_, f);
  }

  void SetOutputGateBias(const std::vector<float>& f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(const std::vector<float>& f) {
    SetData(projection_weights_, weight_type_, f);
  }

  void SetProjectionBias(const std::vector<float>& f) {
    PopulateTensor(projection_bias_, f);
  }

  // Float layer-norm coefficient setters. These are meaningful only when the
  // model was built with layer norm enabled (more than 20 input shapes).
  void SetInputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(input_layer_norm_coefficients_, f);
  }

  void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(forget_layer_norm_coefficients_, f);
  }

  void SetCellLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(cell_layer_norm_coefficients_, f);
  }

  void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(output_layer_norm_coefficients_, f);
  }

  // Copies [begin, end) into the input tensor starting at element `offset`.
  // PopulateTensor takes non-const pointers, hence the const_casts; the data
  // is only read.
  void SetInput(int offset, const float* begin, const float* end) {
    PopulateTensor(input_, offset, const_cast<float*>(begin),
                   const_cast<float*>(end));
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_cells() { return n_cell_; }
  int num_batches() { return n_batch_; }

 protected:
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;
  int input_activation_state_;
  int input_cell_state_;

  // Assigned only when layer norm is enabled in the constructor.
  int input_layer_norm_coefficients_;
  int forget_layer_norm_coefficients_;
  int cell_layer_norm_coefficients_;
  int output_layer_norm_coefficients_;

  int output_;
  // Not assigned by this class.
  int output_state_;
  int cell_state_;

  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;

 private:
  const TensorType weight_type_;

  // Registers a float layer-norm coefficients input when
  // `input_shapes[tensor_index]` describes a real tensor. Otherwise it
  // registers a null (optional) input. A missing entry, an empty shape, or a
  // leading dimension of 0 all count as "no coefficients". This avoids
  // reading past the end of `input_shapes` (or of an empty shape), which was
  // undefined behavior when 21-23 shapes were supplied.
  int AddLayerNormCoeffsTensor(
      int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
    const bool has_coefficients =
        tensor_index >= 0 &&
        static_cast<size_t>(tensor_index) < input_shapes.size() &&
        !input_shapes[tensor_index].empty() &&
        input_shapes[tensor_index][0] != 0;
    if (has_coefficients) {
      return AddInput(TensorType_FLOAT32);
    }
    return AddNullInput();
  }
};

class BaseLstmTest : public ::testing::Test {
 protected:
  // Weights of the LSTM model. Some are optional.
  std::vector<float> input_to_input_weights_;
  std::vector<float> input_to_cell_weights_;
  std::vector<float> input_to_forget_weights_;
  std::vector<float> input_to_output_weights_;
  std::vector<float> input_gate_bias_;
  std::vector<float> cell_gate_bias_;
  std::vector<float> forget_gate_bias_;
  std::vector<float> output_gate_bias_;
  std::vector<float> recurrent_to_input_weights_;
  std::vector<float> recurrent_to_cell_weights_;
  std::vector<float> recurrent_to_forget_weights_;
  std::vector<float> recurrent_to_output_weights_;
  std::vector<float> cell_to_input_weights_;
  std::vector<float> cell_to_forget_weights_;
  std::vector<float> cell_to_output_weights_;
  std::vector<float> projection_weights_;
  std::vector<float> input_layer_norm_coefficients_;
  std::vector<float> forget_layer_norm_coefficients_;
  std::vector<float> cell_layer_norm_coefficients_;
  std::vector<float> output_layer_norm_coefficients_;

  // LSTM input is stored as num_batch x num_inputs vector.
  std::vector<std::vector<float>> lstm_input_;
  // LSTM output is stored as num_batch x num_outputs vector.
  std::vector<std::vector<float>> lstm_golden_output_;

  // Compares output up to tolerance to the result of the lstm given the input.
  void VerifyGoldens(const std::vector<std::vector<float>>& input,
                     const std::vector<std::vector<float>>& output,
                     LSTMOpModel* lstm, float tolerance = 1e-5) {
    const int num_batches = input.size();
    EXPECT_GT(num_batches, 0);
    const int num_inputs = lstm->num_inputs();
    EXPECT_GT(num_inputs, 0);
    const int input_sequence_size = input[0].size() / num_inputs;
    EXPECT_GT(input_sequence_size, 0);
    for (int i = 0; i < input_sequence_size; ++i) {
      for (int b = 0; b < num_batches; ++b) {
        const float* batch_start = input[b].data() + i * num_inputs;
        const float* batch_end = batch_start + num_inputs;

        lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end);
      }

      ASSERT_EQ(lstm->Invoke(), kTfLiteOk);

      const int num_outputs = lstm->num_outputs();
      std::vector<float> expected;
      for (int b = 0; b < num_batches; ++b) {
        const float* golden_start_batch = output[b].data() + i * num_outputs;
        const float* golden_end_batch = golden_start_batch + num_outputs;
        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
      }
      EXPECT_THAT(lstm->GetOutput(),
                  ElementsAreArray(ArrayFloatNear(expected, tolerance)));
    }
  }
};

// Fixture data for an LSTM with all four gates (no CIFG), no peephole
// weights and no projection. SetUp() only fills the input-to-gate and
// recurrent-to-gate weights, the gate biases, and one batch of input with
// its golden output.
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
  // Declared in the class's default (private) section; gtest still calls it
  // through the virtual ::testing::Test::SetUp().
  void SetUp() override {
    // Input-to-gate weights: 8 values each, matching the {n_cell, n_input}
    // = {4, 2} shapes used by the tests on this fixture.
    input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
                               -0.34550029, 0.04266912,  -0.15680569,
                               -0.34856534, 0.43890524};
    input_to_cell_weights_ = {-0.50013041, 0.1370284,  0.11810488, 0.2013163,
                              -0.20583314, 0.44344562, 0.22077113, -0.29909778};
    input_to_forget_weights_ = {0.09701663,  0.20334584,  -0.50592935,
                                -0.31343272, -0.40032279, 0.44781327,
                                0.01387155,  -0.35593212};
    input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
                                0.40525138,  0.44272184,  0.03897077,
                                -0.1556896,  0.19487578};
    // One bias value per cell; only the forget gate bias is non-zero.
    input_gate_bias_ = {0., 0., 0., 0.};
    cell_gate_bias_ = {0., 0., 0., 0.};
    forget_gate_bias_ = {1., 1., 1., 1.};
    output_gate_bias_ = {0., 0., 0., 0.};

    // Recurrent-to-gate weights: 16 values each, matching the
    // {n_cell, n_output} = {4, 4} shapes.
    recurrent_to_input_weights_ = {
        -0.0063535,  -0.2042388,  0.31454784,  -0.35746509,
        0.28902304,  0.08183324,  -0.16555229, 0.02286911,
        -0.13566875, 0.03034258,  0.48091322,  -0.12528998,
        0.24077177,  -0.51332325, -0.33502164, 0.10629296};

    recurrent_to_cell_weights_ = {
        -0.3407414,  0.24443203,  -0.2078532,  0.26320225,
        0.05695659,  -0.00123841, -0.4744786,  -0.35869038,
        -0.06418842, -0.13502428, -0.501764,   0.22830659,
        -0.46367589, 0.26016325,  -0.03894562, -0.16368064};

    recurrent_to_forget_weights_ = {
        -0.48684245, -0.06655136, 0.42224967,  0.2112639,
        0.27654213,  0.20864892,  -0.07646349, 0.45877004,
        0.00141793,  -0.14609534, 0.36447752,  0.09196436,
        0.28053468,  0.01560611,  -0.20127171, -0.01140004};

    recurrent_to_output_weights_ = {
        0.43385774,  -0.17194885, 0.2718237,  0.09215671,
        0.24107647,  -0.39835793, 0.18212086, 0.01301402,
        0.48572797,  -0.50656658, 0.20047462, -0.20607421,
        -0.51818722, -0.15390486, 0.0468148,  0.39922136};

    // One batch: 6 input values = 3 time steps x 2 inputs, and 12 golden
    // output values = 3 time steps x 4 outputs (see VerifyGoldens()).
    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
    lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
                            -0.03716109, 0.12507336, 0.41193449, -0.20860538,
                            -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
  }
};

TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
  // One batch with two input features per step. With no projection layer,
  // the output width equals the number of cells.
  constexpr int n_batch = 1;
  constexpr int n_input = 2;
  constexpr int n_cell = 4;
  constexpr int n_output = n_cell;

  // All four gates are present; peephole and projection tensors get
  // zero-sized shapes, and both clip values are 0.0.
  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       // Input.
                       {n_batch, n_input},

                       // Input-to-gate weights: input, forget, cell, output.
                       {n_cell, n_input},
                       {n_cell, n_input},
                       {n_cell, n_input},
                       {n_cell, n_input},

                       // Recurrent weights: input, forget, cell, output.
                       {n_cell, n_output},
                       {n_cell, n_output},
                       {n_cell, n_output},
                       {n_cell, n_output},

                       // Peephole weights: input, forget, output (unused).
                       {0},
                       {0},
                       {0},

                       // Gate biases: input, forget, cell, output.
                       {n_cell},
                       {n_cell},
                       {n_cell},
                       {n_cell},

                       // Projection weights and bias (unused).
                       {0, 0},
                       {0},

                       // Activation state, then cell state.
                       {n_batch, n_output},
                       {n_batch, n_cell},
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Input gate.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetInputGateBias(input_gate_bias_);

  // Forget gate.
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetForgetGateBias(forget_gate_bias_);

  // Cell gate.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetCellBias(cell_gate_bias_);

  // Output gate.
  lstm.SetInputToOutputWeights(input_to_output_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}

// Reuses the weights and goldens of
// NoCifgNoPeepholeNoProjectionNoClippingLstmTest. Its test builds the model
// with the four layer-norm coefficient inputs given zero-sized shapes.
class NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest
    : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {
  // Intentionally empty: SetUp() and all fixture data are inherited.
};

TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
       LstmBlackBoxTest) {
  // Same dimensions as the non-layer-norm variant: without projection, the
  // output width equals the number of cells.
  constexpr int n_batch = 1;
  constexpr int n_input = 2;
  constexpr int n_cell = 4;
  constexpr int n_output = n_cell;

  // Same tensors as the non-layer-norm variant, plus four trailing
  // layer-norm coefficient inputs, all given zero-sized shapes.
  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       // Input.
                       {n_batch, n_input},

                       // Input-to-gate weights: input, forget, cell, output.
                       {n_cell, n_input},
                       {n_cell, n_input},
                       {n_cell, n_input},
                       {n_cell, n_input},

                       // Recurrent weights: input, forget, cell, output.
                       {n_cell, n_output},
                       {n_cell, n_output},
                       {n_cell, n_output},
                       {n_cell, n_output},

                       // Peephole weights: input, forget, output (unused).
                       {0},
                       {0},
                       {0},

                       // Gate biases: input, forget, cell, output.
                       {n_cell},
                       {n_cell},
                       {n_cell},
                       {n_cell},

                       // Projection weights and bias (unused).
                       {0, 0},
                       {0},

                       // Activation state, then cell state.
                       {n_batch, n_output},
                       {n_batch, n_cell},

                       // Layer-norm coefficients: input, forget, cell,
                       // output (all omitted).
                       {0},
                       {0},
                       {0},
                       {0},
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Input gate.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetInputGateBias(input_gate_bias_);

  // Forget gate.
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetForgetGateBias(forget_gate_bias_);

  // Cell gate.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetCellBias(cell_gate_bias_);

  // Output gate.
  lstm.SetInputToOutputWeights(input_to_output_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}

// Fixture data for a CIFG LSTM: there are no input-gate weights or bias, so
// SetUp() leaves every input_to_input/recurrent_to_input/input_gate vector
// empty.
//
// NOTE(review): despite "NoPeephole" in the name, this fixture fills
// cell_to_forget_weights_ and cell_to_output_weights_ (peephole weights),
// and its test builds the model with use_peephole=true. Consider renaming
// the fixture together with its TEST_F.
class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
  // Declared in the class's default (private) section; gtest still calls it
  // through the virtual ::testing::Test::SetUp().
  void SetUp() override {
    // Input-to-gate weights for the cell, forget and output gates: 8 values
    // each, matching the {n_cell, n_input} = {4, 2} shapes in the test.
    input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
                              0.05100781,  0.04717243,  0.48944736,
                              -0.38535351, -0.17212132};

    input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
                                -0.3633365,  -0.22755712, 0.28253698,
                                0.24407166,  0.33826375};

    input_to_output_weights_ = {0.10725588,  -0.02335852, -0.55932593,
                                -0.09426838, -0.44257352, 0.54939759,
                                0.01533556,  0.42751634};
    // One bias value per cell; only the forget gate bias is non-zero.
    cell_gate_bias_ = {0., 0., 0., 0.};
    forget_gate_bias_ = {1., 1., 1., 1.};
    output_gate_bias_ = {0., 0., 0., 0.};

    // Recurrent-to-gate weights: 16 values each, matching the
    // {n_cell, n_output} = {4, 4} shapes.
    recurrent_to_cell_weights_ = {
        0.54066205,  -0.32668582, -0.43562764, -0.56094903,
        0.42957711,  0.01841056,  -0.32764608, -0.33027974,
        -0.10826075, 0.20675004,  0.19069612,  -0.03026325,
        -0.54532051, 0.33003211,  0.44901288,  0.21193194};

    recurrent_to_forget_weights_ = {
        -0.13832897, -0.0515101,  -0.2359007, -0.16661474,
        -0.14340827, 0.36986142,  0.23414481, 0.55899,
        0.10798943,  -0.41174671, 0.17751795, -0.34484994,
        -0.35874045, -0.11352962, 0.27268326, 0.54058349};

    recurrent_to_output_weights_ = {
        0.41613156, 0.42610586,  -0.16495961, -0.5663873,
        0.30579174, -0.05115908, -0.33941799, 0.23364776,
        0.11178309, 0.09481031,  -0.26424935, 0.46261835,
        0.50248802, 0.26114327,  -0.43736315, 0.33149987};

    // Peephole weights: one value per cell, matching the {n_cell} shapes.
    cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
                               0.31544167};
    cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
                               -0.77109635};

    // One batch: 6 input values = 3 time steps x 2 inputs, and 12 golden
    // output values = 3 time steps x 4 outputs (see VerifyGoldens()).
    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
    lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
                            -0.42312205, -0.01218222, 0.24201041, -0.08124574,
                            -0.358325, -0.04621704, 0.21641694, -0.06471302}};
  }
};

TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
  // One batch with two input features per step. With no projection layer,
  // the output width equals the number of cells.
  constexpr int n_batch = 1;
  constexpr int n_input = 2;
  constexpr int n_cell = 4;
  constexpr int n_output = n_cell;

  // CIFG: every input-gate tensor has a zero-sized shape. Despite the
  // fixture's name, use_peephole is true, and the cell_to_forget and
  // cell_to_output peephole weights are populated below.
  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/true, /*use_peephole=*/true,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       // Input.
                       {n_batch, n_input},

                       // Input-to-gate weights: input (omitted), forget,
                       // cell, output.
                       {0, 0},
                       {n_cell, n_input},
                       {n_cell, n_input},
                       {n_cell, n_input},

                       // Recurrent weights: input (omitted), forget, cell,
                       // output.
                       {0, 0},
                       {n_cell, n_output},
                       {n_cell, n_output},
                       {n_cell, n_output},

                       // Peephole weights: input (omitted), forget, output.
                       {0},
                       {n_cell},
                       {n_cell},

                       // Gate biases: input (omitted), forget, cell, output.
                       {0},
                       {n_cell},
                       {n_cell},
                       {n_cell},

                       // Projection weights and bias (unused).
                       {0, 0},
                       {0},

                       // Activation state, then cell state.
                       {n_batch, n_output},
                       {n_batch, n_cell},
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Forget gate, including its peephole weights.
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetForgetGateBias(forget_gate_bias_);

  // Cell gate.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetCellBias(cell_gate_bias_);

  // Output gate, including its peephole weights.
  lstm.SetInputToOutputWeights(input_to_output_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}

class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
  void SetUp() override {
    input_to_input_weights_ = {
        0.021393683,  0.06124551,    0.046905167,  -0.014657677,  -0.03149463,
        0.09171803,   0.14647801,    0.10797193,   -0.0057968358, 0.0019193048,
        -0.2726754,   0.10154029,    -0.018539885, 0.080349885,   -0.10262385,
        -0.022599787, -0.09121155,   -0.008675967, -0.045206103,  -0.0821282,
        -0.008045952, 0.015478081,   0.055217247,  0.038719587,   0.044153627,
        -0.06453243,  0.05031825,    -0.046935108, -0.008164439,  0.014574226,
        -0.1671009,   -0.15519552,   -0.16819797,  -0.13971269,   -0.11953059,
        0.25005487,   -0.22790983,   0.009855087,  -0.028140958,  -0.11200698,
        0.11295408,   -0.0035217577, 0.054485075,  0.05184695,    0.064711206,
        0.10989193,   0.11674786,    0.03490607,   0.07727357,    0.11390585,
        -0.1863375,   -0.1034451,    -0.13945189,  -0.049401227,  -0.18767063,
        0.042483903,  0.14233552,    0.13832581,   0.18350165,    0.14545603,
        -0.028545704, 0.024939531,   0.050929718,  0.0076203286,  -0.0029723682,
        -0.042484224, -0.11827596,   -0.09171104,  -0.10808628,   -0.16327988,
        -0.2273378,   -0.0993647,    -0.017155107, 0.0023917493,  0.049272764,
        0.0038534778, 0.054764505,   0.089753784,  0.06947234,    0.08014476,
        -0.04544234,  -0.0497073,    -0.07135631,  -0.048929106,  -0.004042012,
        -0.009284026, 0.018042054,   0.0036860977, -0.07427302,   -0.11434604,
        -0.018995456, 0.031487543,   0.012834908,  0.019977754,   0.044256654,
        -0.39292613,  -0.18519334,   -0.11651281,  -0.06809892,   0.011373677};

    input_to_forget_weights_ = {
        -0.0018401089, -0.004852237, 0.03698424,    0.014181704,
        0.028273236,   -0.016726194, -0.05249759,   -0.10204261,
        0.00861066,    -0.040979505, -0.009899187,  0.01923892,
        -0.028177269,  -0.08535103,  -0.14585495,   0.10662567,
        -0.01909731,   -0.017883534, -0.0047269356, -0.045103323,
        0.0030784295,  0.076784775,  0.07463696,    0.094531395,
        0.0814421,     -0.12257899,  -0.033945758,  -0.031303465,
        0.045630626,   0.06843887,   -0.13492945,   -0.012480007,
        -0.0811829,    -0.07224499,  -0.09628791,   0.045100946,
        0.0012300825,  0.013964662,  0.099372394,   0.02543059,
        0.06958324,    0.034257296,  0.0482646,     0.06267997,
        0.052625068,   0.12784666,   0.07077897,    0.025725935,
        0.04165009,    0.07241905,   0.018668644,   -0.037377294,
        -0.06277783,   -0.08833636,  -0.040120605,  -0.011405586,
        -0.007808335,  -0.010301386, -0.005102167,  0.027717464,
        0.05483423,    0.11449111,   0.11289652,    0.10939839,
        0.13396506,    -0.08402166,  -0.01901462,   -0.044678304,
        -0.07720565,   0.014350063,  -0.11757958,   -0.0652038,
        -0.08185733,   -0.076754324, -0.092614375,  0.10405491,
        0.052960336,   0.035755895,  0.035839386,   -0.012540553,
        0.036881298,   0.02913376,   0.03420159,    0.05448447,
        -0.054523353,  0.02582715,   0.02327355,    -0.011857179,
        -0.0011980024, -0.034641717, -0.026125094,  -0.17582615,
        -0.15923657,   -0.27486774,  -0.0006143371, 0.0001771948,
        -8.470171e-05, 0.02651807,   0.045790765,   0.06956496};

    input_to_cell_weights_ = {
        -0.04580283,   -0.09549462,   -0.032418985,  -0.06454633,
        -0.043528453,  0.043018587,   -0.049152344,  -0.12418144,
        -0.078985475,  -0.07596889,   0.019484362,   -0.11434962,
        -0.0074034138, -0.06314844,   -0.092981495,  0.0062155537,
        -0.025034338,  -0.0028890965, 0.048929527,   0.06235075,
        0.10665918,    -0.032036792,  -0.08505916,   -0.10843358,
        -0.13002433,   -0.036816437,  -0.02130134,   -0.016518239,
        0.0047691227,  -0.0025825808, 0.066017866,   0.029991534,
        -0.10652836,   -0.1037554,    -0.13056071,   -0.03266643,
        -0.033702414,  -0.006473424,  -0.04611692,   0.014419339,
        -0.025174323,  0.0396852,     0.081777506,   0.06157468,
        0.10210095,    -0.009658194,  0.046511717,   0.03603906,
        0.0069369148,  0.015960095,   -0.06507666,   0.09551598,
        0.053568836,   0.06408714,    0.12835667,    -0.008714329,
        -0.20211966,   -0.12093674,   0.029450472,   0.2849013,
        -0.029227901,  0.1164364,     -0.08560263,   0.09941786,
        -0.036999565,  -0.028842626,  -0.0033637602, -0.017012902,
        -0.09720865,   -0.11193351,   -0.029155117,  -0.017936034,
        -0.009768936,  -0.04223324,   -0.036159635,  0.06505112,
        -0.021742892,  -0.023377212,  -0.07221364,   -0.06430552,
        0.05453865,    0.091149814,   0.06387331,    0.007518393,
        0.055960953,   0.069779344,   0.046411168,   0.10509911,
        0.07463894,    0.0075130584,  0.012850982,   0.04555431,
        0.056955688,   0.06555285,    0.050801456,   -0.009862683,
        0.00826772,    -0.026555609,  -0.0073611983, -0.0014897042};

    input_to_output_weights_ = {
        -0.0998932,   -0.07201956,  -0.052803773,  -0.15629593,  -0.15001918,
        -0.07650751,  0.02359855,   -0.075155355,  -0.08037709,  -0.15093534,
        0.029517552,  -0.04751393,  0.010350531,   -0.02664851,  -0.016839722,
        -0.023121163, 0.0077019283, 0.012851257,   -0.05040649,  -0.0129761,
        -0.021737747, -0.038305793, -0.06870586,   -0.01481247,  -0.001285394,
        0.10124236,   0.083122835,  0.053313006,   -0.062235646, -0.075637154,
        -0.027833903, 0.029774971,  0.1130802,     0.09218906,   0.09506135,
        -0.086665764, -0.037162706, -0.038880914,  -0.035832845, -0.014481564,
        -0.09825003,  -0.12048569,  -0.097665586,  -0.05287633,  -0.0964047,
        -0.11366429,  0.035777505,  0.13568819,    0.052451383,  0.050649304,
        0.05798951,   -0.021852335, -0.099848844,  0.014740475,  -0.078897946,
        0.04974699,   0.014160473,  0.06973932,    0.04964942,   0.033364646,
        0.08190124,   0.025535367,  0.050893165,   0.048514254,  0.06945813,
        -0.078907564, -0.06707616,  -0.11844508,   -0.09986688,  -0.07509403,
        0.06263226,   0.14925587,   0.20188436,    0.12098451,   0.14639415,
        0.0015017595, -0.014267382, -0.03417257,   0.012711468,  0.0028300495,
        -0.024758482, -0.05098548,  -0.0821182,    0.014225672,  0.021544158,
        0.08949725,   0.07505268,   -0.0020780868, 0.04908258,   0.06476295,
        -0.022907063, 0.027562456,  0.040185735,   0.019567577,  -0.015598739,
        -0.049097303, -0.017121866, -0.083368234,  -0.02332002,  -0.0840956};

    input_gate_bias_ = {0.02234832,   0.14757581,  0.18176508,  0.10380666,
                        0.053110216,  -0.06928846, -0.13942584, -0.11816189,
                        0.19483899,   0.03652339,  -0.10250295, 0.036714908,
                        -0.18426876,  0.036065217, 0.21810818,  0.02383196,
                        -0.043370757, 0.08690144,  -0.04444982, 0.00030581196};

    forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
                         0.11098921,  0.15378423,   0.09263801,  0.09790885,
                         0.09508917,  0.061199076,  0.07665568,  -0.015443159,
                         -0.03499149, 0.046190713,  0.08895977,  0.10899629,
                         0.40694186,  0.06030037,   0.012413437, -0.06108739};

    cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132,   0.033463873,
                       -0.1483596,   -0.10639995,  -0.091433935, 0.058573797,
                       -0.06809782,  -0.07889636,  -0.043246906, -0.09829136,
                       -0.4279842,   0.034901652,  0.18797937,   0.0075234566,
                       0.016178843,  0.1749513,    0.13975595,   0.92058027};

    output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469,   0.12648113,
                         0.027195795, 0.35373217,    -0.018957434, 0.008907322,
                         -0.0762701,  0.12018895,    0.04216877,   0.0022856654,
                         0.040952638, 0.3147856,     0.08225149,   -0.057416286,
                         -0.14995944, -0.008040261,  0.13208859,   0.029760877};

    recurrent_to_input_weights_ = {
        -0.001374326,   -0.078856036,   0.10672688,    0.029162422,
        -0.11585556,    0.02557986,     -0.13446963,   -0.035785314,
        -0.01244275,    0.025961924,    -0.02337298,   -0.044228926,
        -0.055839065,   -0.046598054,   -0.010546039,  -0.06900766,
        0.027239809,    0.022582639,    -0.013296484,  -0.05459212,
        0.08981,        -0.045407712,   0.08682226,    -0.06867011,
        -0.14390695,    -0.02916037,    0.000996957,   0.091420636,
        0.14283475,     -0.07390571,    -0.06402044,   0.062524505,
        -0.093129106,   0.04860203,     -0.08364217,   -0.08119002,
        0.009352075,    0.22920375,     0.0016303885,  0.11583097,
        -0.13732095,    0.012405723,    -0.07551853,   0.06343048,
        0.12162708,     -0.031923793,   -0.014335606,  0.01790974,
        -0.10650317,    -0.0724401,     0.08554849,    -0.05727212,
        0.06556731,     -0.042729504,   -0.043227166,  0.011683251,
        -0.013082158,   -0.029302018,   -0.010899579,  -0.062036745,
        -0.022509435,   -0.00964907,    -0.01567329,   0.04260106,
        -0.07787477,    -0.11576462,    0.017356863,   0.048673786,
        -0.017577527,   -0.05527947,    -0.082487635,  -0.040137455,
        -0.10820036,    -0.04666372,    0.022746278,   -0.07851417,
        0.01068115,     0.032956902,    0.022433773,   0.0026891115,
        0.08944216,     -0.0685835,     0.010513544,   0.07228705,
        0.02032331,     -0.059686817,   -0.0005566496, -0.086984694,
        0.040414046,    -0.1380399,     0.094208956,   -0.05722982,
        0.012092817,    -0.04989123,    -0.086576,     -0.003399834,
        -0.04696032,    -0.045747425,   0.10091314,    0.048676282,
        -0.029037097,   0.031399418,    -0.0040285117, 0.047237843,
        0.09504992,     0.041799378,    -0.049185462,  -0.031518843,
        -0.10516937,    0.026374253,    0.10058866,    -0.0033195973,
        -0.041975245,   0.0073591834,   0.0033782164,  -0.004325073,
        -0.10167381,    0.042500053,    -0.01447153,   0.06464186,
        -0.017142897,   0.03312627,     0.009205989,   0.024138335,
        -0.011337001,   0.035530265,    -0.010912711,  0.0706555,
        -0.005894094,   0.051841937,    -0.1401738,    -0.02351249,
        0.0365468,      0.07590991,     0.08838724,    0.021681072,
        -0.10086113,    0.019608743,    -0.06195883,   0.077335775,
        0.023646897,    -0.095322326,   0.02233014,    0.09756986,
        -0.048691444,   -0.009579111,   0.07595467,    0.11480546,
        -0.09801813,    0.019894179,    0.08502348,    0.004032281,
        0.037211012,    0.068537936,    -0.048005626,  -0.091520436,
        -0.028379958,   -0.01556313,    0.06554592,    -0.045599163,
        -0.01672207,    -0.020169014,   -0.011877351,  -0.20212261,
        0.010889619,    0.0047078193,   0.038385306,   0.08540671,
        -0.017140968,   -0.0035865551,  0.016678626,   0.005633034,
        0.015963363,    0.00871737,     0.060130805,   0.028611384,
        0.10109069,     -0.015060172,   -0.07894427,   0.06401885,
        0.011584063,    -0.024466386,   0.0047652307,  -0.09041358,
        0.030737216,    -0.0046374933,  0.14215417,    -0.11823516,
        0.019899689,    0.006106124,    -0.027092824,  0.0786356,
        0.05052217,     -0.058925,      -0.011402121,  -0.024987547,
        -0.0013661642,  -0.06832946,    -0.015667673,  -0.1083353,
        -0.00096863037, -0.06988685,    -0.053350925,  -0.027275559,
        -0.033664223,   -0.07978348,    -0.025200296,  -0.017207067,
        -0.058403496,   -0.055697463,   0.005798788,   0.12965427,
        -0.062582195,   0.0013350133,   -0.10482091,   0.0379771,
        0.072521195,    -0.0029455067,  -0.13797039,   -0.03628521,
        0.013806405,    -0.017858358,   -0.01008298,   -0.07700066,
        -0.017081132,   0.019358726,    0.0027079724,  0.004635139,
        0.062634714,    -0.02338735,    -0.039547626,  -0.02050681,
        0.03385117,     -0.083611414,   0.002862572,   -0.09421313,
        0.058618143,    -0.08598433,    0.00972939,    0.023867095,
        -0.053934585,   -0.023203006,   0.07452513,    -0.048767887,
        -0.07314807,    -0.056307215,   -0.10433547,   -0.06440842,
        0.04328182,     0.04389765,     -0.020006588,  -0.09076438,
        -0.11652589,    -0.021705797,   0.03345259,    -0.010329105,
        -0.025767034,   0.013057034,    -0.07316461,   -0.10145612,
        0.06358255,     0.18531723,     0.07759293,    0.12006465,
        0.1305557,      0.058638252,    -0.03393652,   0.09622831,
        -0.16253184,    -2.4580743e-06, 0.079869635,   -0.070196845,
        -0.005644518,   0.06857898,     -0.12598175,   -0.035084512,
        0.03156317,     -0.12794146,    -0.031963028,  0.04692781,
        0.030070418,    0.0071660685,   -0.095516115,  -0.004643372,
        0.040170413,    -0.062104587,   -0.0037324072, 0.0554317,
        0.08184801,     -0.019164372,   0.06791302,    0.034257166,
        -0.10307039,    0.021943003,    0.046745934,   0.0790918,
        -0.0265588,     -0.007824208,   0.042546265,   -0.00977924,
        -0.0002440307,  -0.017384544,   -0.017990116,  0.12252321,
        -0.014512694,   -0.08251313,    0.08861942,    0.13589665,
        0.026351685,    0.012641483,    0.07466548,    0.044301085,
        -0.045414884,   -0.051112458,   0.03444247,    -0.08502782,
        -0.04106223,    -0.028126027,   0.028473156,   0.10467447};

    recurrent_to_cell_weights_ = {
        -0.037322544,   0.018592842,   0.0056175636,  -0.06253426,
        0.055647098,    -0.05713207,   -0.05626563,   0.005559383,
        0.03375411,     -0.025757805,  -0.088049285,  0.06017052,
        -0.06570978,    0.007384076,   0.035123326,   -0.07920549,
        0.053676967,    0.044480428,   -0.07663568,   0.0071805613,
        0.08089997,     0.05143358,    0.038261272,   0.03339287,
        -0.027673481,   0.044746667,   0.028349208,   0.020090483,
        -0.019443132,   -0.030755889,  -0.0040000007, 0.04465846,
        -0.021585021,   0.0031670958,  0.0053199246,  -0.056117613,
        -0.10893326,    0.076739706,   -0.08509834,   -0.027997585,
        0.037871376,    0.01449768,    -0.09002357,   -0.06111149,
        -0.046195522,   0.0422062,     -0.005683705,  -0.1253618,
        -0.012925729,   -0.04890792,   0.06985068,    0.037654128,
        0.03398274,     -0.004781977,  0.007032333,   -0.031787455,
        0.010868644,    -0.031489216,  0.09525667,    0.013939797,
        0.0058680447,   0.0167067,     0.02668468,    -0.04797466,
        -0.048885044,   -0.12722108,   0.035304096,   0.06554885,
        0.00972396,     -0.039238118,  -0.05159735,   -0.11329045,
        0.1613692,      -0.03750952,   0.06529313,    -0.071974665,
        -0.11769596,    0.015524369,   -0.0013754242, -0.12446318,
        0.02786344,     -0.014179351,  0.005264273,   0.14376344,
        0.015983658,    0.03406988,    -0.06939408,   0.040699873,
        0.02111075,     0.09669095,    0.041345075,   -0.08316494,
        -0.07684199,    -0.045768797,  0.032298047,   -0.041805092,
        0.0119405,      0.0061010392,  0.12652606,    0.0064572375,
        -0.024950314,   0.11574242,    0.04508852,    -0.04335324,
        0.06760663,     -0.027437469,  0.07216407,    0.06977076,
        -0.05438599,    0.034033038,   -0.028602652,  0.05346137,
        0.043184172,    -0.037189785,  0.10420091,    0.00882477,
        -0.054019816,   -0.074273005,  -0.030617684,  -0.0028467078,
        0.024302477,    -0.0038869337, 0.005332455,   0.0013399826,
        0.04361412,     -0.007001822,  0.09631092,    -0.06702025,
        -0.042049985,   -0.035070654,  -0.04103342,   -0.10273396,
        0.0544271,      0.037184782,   -0.13150354,   -0.0058036847,
        -0.008264958,   0.042035464,   0.05891794,    0.029673764,
        0.0063542654,   0.044788733,   0.054816857,   0.062257513,
        -0.00093483756, 0.048938446,   -0.004952862,  -0.007730018,
        -0.04043371,    -0.017094059,  0.07229206,    -0.023670016,
        -0.052195564,   -0.025616996,  -0.01520939,   0.045104615,
        -0.007376126,   0.003533447,   0.006570588,   0.056037236,
        0.12436656,     0.051817212,   0.028532185,   -0.08686856,
        0.11868599,     0.07663395,    -0.07323171,   0.03463402,
        -0.050708205,   -0.04458982,   -0.11590894,   0.021273347,
        0.1251325,      -0.15313013,   -0.12224372,   0.17228661,
        0.023029093,    0.086124025,   0.006445803,   -0.03496501,
        0.028332196,    0.04449512,    -0.042436164,  -0.026587414,
        -0.006041347,   -0.09292539,   -0.05678812,   0.03897832,
        0.09465633,     0.008115513,   -0.02171956,   0.08304309,
        0.071401566,    0.019622514,   0.032163795,   -0.004167056,
        0.02295182,     0.030739572,   0.056506045,   0.004612461,
        0.06524936,     0.059999723,   0.046395954,   -0.0045512207,
        -0.1335546,     -0.030136576,  0.11584653,    -0.014678886,
        0.0020118146,   -0.09688814,   -0.0790206,    0.039770417,
        -0.0329582,     0.07922767,    0.029322514,   0.026405897,
        0.04207835,     -0.07073373,   0.063781224,   0.0859677,
        -0.10925287,    -0.07011058,   0.048005477,   0.03438226,
        -0.09606514,    -0.006669445,  -0.043381985,  0.04240257,
        -0.06955775,    -0.06769346,   0.043903265,   -0.026784198,
        -0.017840602,   0.024307009,   -0.040079936,  -0.019946516,
        0.045318738,    -0.12233574,   0.026170589,   0.0074471775,
        0.15978073,     0.10185836,    0.10298046,    -0.015476589,
        -0.039390966,   -0.072174534,  0.0739445,     -0.1211869,
        -0.0347889,     -0.07943156,   0.014809798,   -0.12412325,
        -0.0030663363,  0.039695457,   0.0647603,     -0.08291318,
        -0.018529687,   -0.004423833,  0.0037507233,  0.084633216,
        -0.01514876,    -0.056505352,  -0.012800942,  -0.06994386,
        0.012962922,    -0.031234352,  0.07029052,    0.016418684,
        0.03618972,     0.055686004,   -0.08663945,   -0.017404709,
        -0.054761406,   0.029065743,   0.052404847,   0.020238016,
        0.0048197987,   -0.0214882,    0.07078733,    0.013016777,
        0.06262858,     0.009184685,   0.020785125,   -0.043904778,
        -0.0270329,     -0.03299152,   -0.060088247,  -0.015162964,
        -0.001828936,   0.12642565,    -0.056757294,  0.013586685,
        0.09232601,     -0.035886683,  0.06000002,    0.05229691,
        -0.052580316,   -0.082029596,  -0.010794592,  0.012947712,
        -0.036429964,   -0.085508935,  -0.13127148,   -0.017744139,
        0.031502828,    0.036232427,   -0.031581745,  0.023051167,
        -0.05325106,    -0.03421577,   0.028793324,   -0.034633752,
        -0.009881397,   -0.043551125,  -0.018609839,  0.0019097115,
        -0.008799762,   0.056595087,   0.0022273948,  0.055752404};

    recurrent_to_forget_weights_ = {
        -0.057784554,  -0.026057621,  -0.068447545,   -0.022581743,
        0.14811787,    0.10826372,    0.09471067,     0.03987225,
        -0.0039523416, 0.00030638507, 0.053185795,    0.10572994,
        0.08414449,    -0.022036452,  -0.00066928595, -0.09203576,
        0.032950465,   -0.10985798,   -0.023809856,   0.0021431844,
        -0.02196096,   -0.00326074,   0.00058621005,  -0.074678116,
        -0.06193199,   0.055729095,   0.03736828,     0.020123724,
        0.061878487,   -0.04729229,   0.034919553,    -0.07585433,
        -0.04421272,   -0.044019096,  0.085488975,    0.04058006,
        -0.06890133,   -0.030951202,  -0.024628663,   -0.07672815,
        0.034293607,   0.08556707,    -0.05293577,    -0.033561368,
        -0.04899627,   0.0241671,     0.015736353,    -0.095442444,
        -0.029564252,  0.016493602,   -0.035026584,   0.022337519,
        -0.026871363,  0.004780428,   0.0077918363,   -0.03601621,
        0.016435321,   -0.03263031,   -0.09543275,    -0.047392778,
        0.013454138,   0.028934088,   0.01685226,     -0.086110644,
        -0.046250615,  -0.01847454,   0.047608484,    0.07339695,
        0.034546845,   -0.04881143,   0.009128804,    -0.08802852,
        0.03761666,    0.008096139,   -0.014454086,   0.014361001,
        -0.023502491,  -0.0011840804, -0.07607001,    0.001856849,
        -0.06509276,   -0.006021153,  -0.08570962,    -0.1451793,
        0.060212336,   0.055259194,   0.06974018,     0.049454916,
        -0.027794661,  -0.08077226,   -0.016179763,   0.1169753,
        0.17213494,    -0.0056326236, -0.053934924,   -0.0124349,
        -0.11520337,   0.05409887,    0.088759385,    0.0019655675,
        0.0042065294,  0.03881498,    0.019844765,    0.041858196,
        -0.05695512,   0.047233116,   0.038937137,    -0.06542224,
        0.014429736,   -0.09719407,   0.13908425,     -0.05379757,
        0.012321099,   0.082840554,   -0.029899208,   0.044217527,
        0.059855383,   0.07711018,    -0.045319796,   0.0948846,
        -0.011724666,  -0.0033288454, -0.033542685,   -0.04764985,
        -0.13873616,   0.040668588,   0.034832682,    -0.015319203,
        -0.018715994,  0.046002675,   0.0599172,      -0.043107376,
        0.0294216,     -0.002314414,  -0.022424703,   0.0030315618,
        0.0014641669,  0.0029166266,  -0.11878115,    0.013738511,
        0.12375372,    -0.0006038222, 0.029104086,    0.087442465,
        0.052958444,   0.07558703,    0.04817258,     0.044462286,
        -0.015213451,  -0.08783778,   -0.0561384,     -0.003008196,
        0.047060397,   -0.002058388,  0.03429439,     -0.018839769,
        0.024734668,   0.024614193,   -0.042046934,   0.09597743,
        -0.0043254104, 0.04320769,    0.0064070094,   -0.0019131786,
        -0.02558259,   -0.022822596,  -0.023273505,   -0.02464396,
        -0.10991725,   -0.006240552,  0.0074488563,   0.024044557,
        0.04383914,    -0.046476185,  0.028658995,    0.060410924,
        0.050786525,   0.009452605,   -0.0073054377,  -0.024810238,
        0.0052906186,  0.0066939713,  -0.0020913032,  0.014515517,
        0.015898481,   0.021362653,   -0.030262267,   0.016587038,
        -0.011442813,  0.041154444,   -0.007631438,   -0.03423484,
        -0.010977775,  0.036152758,   0.0066366293,   0.11915515,
        0.02318443,    -0.041350313,  0.021485701,    -0.10906167,
        -0.028218046,  -0.00954771,   0.020531068,    -0.11995105,
        -0.03672871,   0.024019798,   0.014255957,    -0.05221243,
        -0.00661567,   -0.04630967,   0.033188973,    0.10107534,
        -0.014027541,  0.030796422,   -0.10270911,    -0.035999842,
        0.15443139,    0.07684145,    0.036571592,    -0.035900835,
        -0.0034699554, 0.06209149,    0.015920248,    -0.031122351,
        -0.03858649,   0.01849943,    0.13872518,     0.01503974,
        0.069941424,   -0.06948533,   -0.0088794185,  0.061282158,
        -0.047401894,  0.03100163,    -0.041533746,   -0.10430945,
        0.044574402,   -0.01425562,   -0.024290353,   0.034563623,
        0.05866852,    0.023947537,   -0.09445152,    0.035450947,
        0.02247216,    -0.0042998926, 0.061146557,    -0.10250651,
        0.020881841,   -0.06747029,   0.10062043,     -0.0023941975,
        0.03532124,    -0.016341697,  0.09685456,     -0.016764693,
        0.051808182,   0.05875331,    -0.04536488,    0.001626336,
        -0.028892258,  -0.01048663,   -0.009793449,   -0.017093895,
        0.010987891,   0.02357273,    -0.00010856845, 0.0099760275,
        -0.001845119,  -0.03551521,   0.0018358806,   0.05763657,
        -0.01769146,   0.040995963,   0.02235177,     -0.060430344,
        0.11475477,    -0.023854522,  0.10071741,     0.0686208,
        -0.014250481,  0.034261297,   0.047418304,    0.08562733,
        -0.030519066,  0.0060542435,  0.014653856,    -0.038836084,
        0.04096551,    0.032249358,   -0.08355519,    -0.026823482,
        0.056386515,   -0.010401743,  -0.028396193,   0.08507674,
        0.014410365,   0.020995233,   0.17040324,     0.11511526,
        0.02459721,    0.0066619175,  0.025853224,    -0.023133837,
        -0.081302024,  0.017264642,   -0.009585969,   0.09491168,
        -0.051313367,  0.054532815,   -0.014298593,   0.10657464,
        0.007076659,   0.10964551,    0.0409152,      0.008275321,
        -0.07283536,   0.07937492,    0.04192024,     -0.1075027};

    recurrent_to_output_weights_ = {
        0.025825322,   -0.05813119,   0.09495884,     -0.045984812,
        -0.01255415,   -0.0026479573, -0.08196161,    -0.054914974,
        -0.0046604523, -0.029587349,  -0.044576716,   -0.07480124,
        -0.082868785,  0.023254942,   0.027502948,    -0.0039728214,
        -0.08683098,   -0.08116779,   -0.014675607,   -0.037924774,
        -0.023314456,  -0.007401714,  -0.09255757,    0.029460307,
        -0.08829125,   -0.005139627,  -0.08989442,    -0.0555066,
        0.13596267,    -0.025062224,  -0.048351806,   -0.03850004,
        0.07266485,    -0.022414139,  0.05940088,     0.075114764,
        0.09597592,    -0.010211725,  -0.0049794707,  -0.011523867,
        -0.025980417,  0.072999895,   0.11091378,     -0.081685916,
        0.014416728,   0.043229222,   0.034178585,    -0.07530371,
        0.035837382,   -0.085607,     -0.007721233,   -0.03287832,
        -0.043848954,  -0.06404588,   -0.06632928,    -0.073643476,
        0.008214239,   -0.045984086,  0.039764922,    0.03474462,
        0.060612556,   -0.080590084,  0.049127717,    0.04151091,
        -0.030063879,  0.008801774,   -0.023021035,   -0.019558564,
        0.05158114,    -0.010947698,  -0.011825728,   0.0075720972,
        0.0699727,     -0.0039981045, 0.069350146,    0.08799282,
        0.016156472,   0.035502106,   0.11695009,     0.006217345,
        0.13392477,    -0.037875112,  0.025745004,    0.08940699,
        -0.00924166,   0.0046702605,  -0.036598757,   -0.08811812,
        0.10522024,    -0.032441203,  0.008176899,    -0.04454919,
        0.07058152,    0.0067963637,  0.039206743,    0.03259838,
        0.03725492,    -0.09515802,   0.013326398,    -0.052055415,
        -0.025676316,  0.03198509,    -0.015951829,   -0.058556724,
        0.036879618,   0.043357447,   0.028362012,    -0.05908629,
        0.0059240665,  -0.04995891,   -0.019187413,   0.0276265,
        -0.01628143,   0.0025863599,  0.08800015,     0.035250366,
        -0.022165963,  -0.07328642,   -0.009415526,   -0.07455109,
        0.11690406,    0.0363299,     0.07411125,     0.042103454,
        -0.009660886,  0.019076364,   0.018299393,    -0.046004917,
        0.08891175,    0.0431396,     -0.026327137,   -0.051502608,
        0.08979574,    -0.051670972,  0.04940282,     -0.07491107,
        -0.021240504,  0.022596184,   -0.034280192,   0.060163025,
        -0.058211457,  -0.051837247,  -0.01349775,    -0.04639988,
        -0.035936575,  -0.011681591,  0.064818054,    0.0073146066,
        -0.021745546,  -0.043124277,  -0.06471268,    -0.07053354,
        -0.029321948,  -0.05330136,   0.016933719,    -0.053782392,
        0.13747959,    -0.1361751,    -0.11569455,    0.0033329215,
        0.05693899,    -0.053219706,  0.063698,       0.07977434,
        -0.07924483,   0.06936997,    0.0034815092,   -0.007305279,
        -0.037325785,  -0.07251102,   -0.033633437,   -0.08677009,
        0.091591336,   -0.14165086,   0.021752775,    0.019683983,
        0.0011612234,  -0.058154266,  0.049996935,    0.0288841,
        -0.0024567875, -0.14345716,   0.010955264,    -0.10234828,
        0.1183656,     -0.0010731248, -0.023590032,   -0.072285876,
        -0.0724771,    -0.026382286,  -0.0014920527,  0.042667855,
        0.0018776858,  0.02986552,    0.009814309,    0.0733756,
        0.12289186,    0.018043943,   -0.0458958,     0.049412545,
        0.033632483,   0.05495232,    0.036686596,    -0.013781798,
        -0.010036754,  0.02576849,    -0.08307328,    0.010112348,
        0.042521734,   -0.05869831,   -0.071689695,   0.03876447,
        -0.13275425,   -0.0352966,    -0.023077697,   0.10285965,
        0.084736146,   0.15568255,    -0.00040734606, 0.027835453,
        -0.10292561,   -0.032401145,  0.10053256,     -0.026142767,
        -0.08271222,   -0.0030240538, -0.016368777,   0.1070414,
        0.042672627,   0.013456989,   -0.0437609,     -0.022309763,
        0.11576483,    0.04108048,    0.061026827,    -0.0190714,
        -0.0869359,    0.037901703,   0.0610107,      0.07202949,
        0.01675338,    0.086139716,   -0.08795751,    -0.014898893,
        -0.023771819,  -0.01965048,   0.007955471,    -0.043740474,
        0.03346837,    -0.10549954,   0.090567775,    0.042013682,
        -0.03176985,   0.12569028,    -0.02421228,    -0.029526481,
        0.023851605,   0.031539805,   0.05292009,     -0.02344001,
        -0.07811758,   -0.08834428,   0.10094801,     0.16594367,
        -0.06861939,   -0.021256343,  -0.041093912,   -0.06669611,
        0.035498552,   0.021757556,   -0.09302526,    -0.015403468,
        -0.06614931,   -0.051798206,  -0.013874718,   0.03630673,
        0.010412845,   -0.08077351,   0.046185967,    0.0035662893,
        0.03541868,    -0.094149634,  -0.034814864,   0.003128424,
        -0.020674974,  -0.03944324,   -0.008110165,   -0.11113267,
        0.08484226,    0.043586485,   0.040582247,    0.0968012,
        -0.065249965,  -0.028036479,  0.0050708856,   0.0017462453,
        0.0326779,     0.041296225,   0.09164146,     -0.047743853,
        -0.015952192,  -0.034451712,  0.084197424,    -0.05347844,
        -0.11768019,   0.085926116,   -0.08251791,    -0.045081906,
        0.0948852,     0.068401024,   0.024856757,    0.06978981,
        -0.057309967,  -0.012775832,  -0.0032452994,  0.01977615,
        -0.041040014,  -0.024264973,  0.063464895,    0.05431621,
    };

    cell_to_input_weights_ = {
        0.040369894, 0.030746894,  0.24704495,  0.018586371,  -0.037586458,
        -0.15312155, -0.11812848,  -0.11465643, 0.20259799,   0.11418174,
        -0.10116027, -0.011334949, 0.12411352,  -0.076769054, -0.052169047,
        0.21198851,  -0.38871562,  -0.09061183, -0.09683246,  -0.21929175};

    cell_to_forget_weights_ = {
        -0.01998659,  -0.15568835,  -0.24248174,   -0.012770197, 0.041331276,
        -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
        -0.047248036, 0.021479502,  0.033189066,   0.11952997,   -0.020432774,
        0.64658105,   -0.06650122,  -0.03467612,   0.095340036,  0.23647355};

    cell_to_output_weights_ = {
        0.08286371,  -0.08261836, -0.51210177, 0.002913762, 0.17764764,
        -0.5495371,  -0.08460716, -0.24552552, 0.030037103, 0.04123544,
        -0.11940523, 0.007358328, 0.1890978,   0.4833202,   -0.34441817,
        0.36312827,  -0.26375428, 0.1457655,   -0.19724406, 0.15548733};

    projection_weights_ = {
        -0.009802181, 0.09401916,   0.0717386,     -0.13895074,
        0.09641832,   0.060420845,  0.08539281,    0.054285463,
        0.061395317,  0.034448683,  -0.042991187,  0.019801661,
        -0.16840284,  -0.015726732, -0.23041931,   -0.024478018,
        -0.10959692,  -0.013875541, 0.18600968,    -0.061274476,
        0.0138165,    -0.08160894,  -0.07661644,   0.032372914,
        0.16169067,   0.22465782,   -0.03993472,   -0.004017731,
        0.08633481,   -0.28869787,  0.08682067,    0.17240396,
        0.014975425,  0.056431185,  0.031037588,   0.16702051,
        0.0077946745, 0.15140012,   0.29405436,    0.120285,
        -0.188994,    -0.027265169, 0.043389652,   -0.022061434,
        0.014777949,  -0.20203483,  0.094781205,   0.19100232,
        0.13987629,   -0.036132768, -0.06426278,   -0.05108664,
        0.13221376,   0.009441198,  -0.16715929,   0.15859416,
        -0.040437475, 0.050779544,  -0.022187516,  0.012166504,
        0.027685808,  -0.07675938,  -0.0055694645, -0.09444123,
        0.0046453946, 0.050794356,  0.10770313,    -0.20790008,
        -0.07149004,  -0.11425117,  0.008225835,   -0.035802525,
        0.14374903,   0.15262283,   0.048710253,   0.1847461,
        -0.007487823, 0.11000021,   -0.09542012,   0.22619456,
        -0.029149994, 0.08527916,   0.009043713,   0.0042746216,
        0.016261552,  0.022461696,  0.12689082,    -0.043589946,
        -0.12035478,  -0.08361797,  -0.050666027,  -0.1248618,
        -0.1275799,   -0.071875185, 0.07377272,    0.09944291,
        -0.18897448,  -0.1593054,   -0.06526116,   -0.040107165,
        -0.004618631, -0.067624845, -0.007576253,  0.10727444,
        0.041546922,  -0.20424393,  0.06907816,    0.050412357,
        0.00724631,   0.039827548,  0.12449835,    0.10747581,
        0.13708383,   0.09134148,   -0.12617786,   -0.06428341,
        0.09956831,   0.1208086,    -0.14676677,   -0.0727722,
        0.1126304,    0.010139365,  0.015571211,   -0.038128063,
        0.022913318,  -0.042050496, 0.16842307,    -0.060597885,
        0.10531834,   -0.06411776,  -0.07451711,   -0.03410368,
        -0.13393489,  0.06534304,   0.003620307,   0.04490757,
        0.05970546,   0.05197996,   0.02839995,    0.10434969,
        -0.013699693, -0.028353551, -0.07260381,   0.047201227,
        -0.024575593, -0.036445823, 0.07155557,    0.009672501,
        -0.02328883,  0.009533515,  -0.03606021,   -0.07421458,
        -0.028082801, -0.2678904,   -0.13221288,   0.18419984,
        -0.13012612,  -0.014588381, -0.035059117,  -0.04824723,
        0.07830115,   -0.056184657, 0.03277091,    0.025466874,
        0.14494097,   -0.12522776,  -0.098633975,  -0.10766018,
        -0.08317623,  0.08594209,   0.07749552,    0.039474737,
        0.1776665,    -0.07409566,  -0.0477268,    0.29323658,
        0.10801441,   0.1154011,    0.013952499,   0.10739139,
        0.10708251,   -0.051456142, 0.0074137426,  -0.10430189,
        0.10034707,   0.045594677,  0.0635285,     -0.0715442,
        -0.089667566, -0.10811871,  0.00026344223, 0.08298446,
        -0.009525053, 0.006585689,  -0.24567553,   -0.09450807,
        0.09648481,   0.026996298,  -0.06419476,   -0.04752702,
        -0.11063944,  -0.23441927,  -0.17608605,   -0.052156363,
        0.067035615,  0.19271925,   -0.0032889997, -0.043264326,
        0.09663576,   -0.057112187, -0.10100678,   0.0628376,
        0.04447668,   0.017961001,  -0.10094388,   -0.10190601,
        0.18335468,   0.10494553,   -0.052095775,  -0.0026118709,
        0.10539724,   -0.04383912,  -0.042349473,  0.08438151,
        -0.1947263,   0.02251204,   0.11216432,    -0.10307853,
        0.17351969,   -0.039091777, 0.08066188,    -0.00561982,
        0.12633002,   0.11335965,   -0.0088127935, -0.019777594,
        0.06864014,   -0.059751723, 0.016233567,   -0.06894641,
        -0.28651384,  -0.004228674, 0.019708522,   -0.16305895,
        -0.07468996,  -0.0855457,   0.099339016,   -0.07580735,
        -0.13775392,  0.08434318,   0.08330512,    -0.12131499,
        0.031935584,  0.09180414,   -0.08876437,   -0.08049874,
        0.008753825,  0.03498998,   0.030215185,   0.03907079,
        0.089751154,  0.029194152,  -0.03337423,   -0.019092513,
        0.04331237,   0.04299654,   -0.036394123,  -0.12915532,
        0.09793732,   0.07512415,   -0.11319543,   -0.032502122,
        0.15661901,   0.07671967,   -0.005491124,  -0.19379048,
        -0.218606,    0.21448623,   0.017840758,   0.1416943,
        -0.07051762,  0.19488361,   0.02664691,    -0.18104725,
        -0.09334311,  0.15026465,   -0.15493552,   -0.057762887,
        -0.11604192,  -0.262013,    -0.01391798,   0.012185008,
        0.11156489,   -0.07483202,  0.06693364,    -0.26151478,
        0.046425626,  0.036540434,  -0.16435726,   0.17338543,
        -0.21401681,  -0.11385144,  -0.08283257,   -0.069031075,
        0.030635102,  0.010969227,  0.11109743,    0.010919218,
        0.027526086,  0.13519906,   0.01891392,    -0.046839405,
        -0.040167913, 0.017953383,  -0.09700955,   0.0061885654,
        -0.07000971,  0.026893595,  -0.038844477,  0.14543656};

    lstm_input_ = {
        {// Batch0: 4 (input_sequence_size) * 5 (n_input)
         0.787926, 0.151646, 0.071352, 0.118426, 0.458058,   // step 0
         0.596268, 0.998386, 0.568695, 0.864524, 0.571277,   // step 1
         0.073204, 0.296072, 0.743333, 0.069199, 0.045348,   // step 2
         0.867394, 0.291279, 0.013714, 0.482521, 0.626339},  // step 3

        {// Batch1: 4 (input_sequence_size) * 5 (n_input)
         0.295743, 0.544053, 0.690064, 0.858138, 0.497181,  // step 0
         0.642421, 0.524260, 0.134799, 0.003639, 0.162482,  // step 1
         0.640394, 0.930399, 0.050782, 0.432485, 0.988078,  // step 2
         0.082922, 0.563329, 0.865614, 0.333232, 0.259916}  // step 3
    };

    lstm_golden_output_ = {
        {// Batch0: 4 (input_sequence_size) * 16 (n_output)
         -0.00396806, 0.029352,     -0.00279226, 0.0159977,   -0.00835576,
         -0.0211779,  0.0283512,    -0.0114597,  0.00907307,  -0.0244004,
         -0.0152191,  -0.0259063,   0.00914318,  0.00415118,  0.017147,
         0.0134203,   -0.0166936,   0.0381209,   0.000889694, 0.0143363,
         -0.0328911,  -0.0234288,   0.0333051,   -0.012229,   0.0110322,
         -0.0457725,  -0.000832209, -0.0202817,  0.0327257,   0.0121308,
         0.0155969,   0.0312091,    -0.0213783,  0.0350169,   0.000324794,
         0.0276012,   -0.0263374,   -0.0371449,  0.0446149,   -0.0205474,
         0.0103729,   -0.0576349,   -0.0150052,  -0.0292043,  0.0376827,
         0.0136115,   0.0243435,    0.0354492,   -0.0189322,  0.0464512,
         -0.00251373, 0.0225745,    -0.0308346,  -0.0317124,  0.0460407,
         -0.0189395,  0.0149363,    -0.0530162,  -0.0150767,  -0.0340193,
         0.0286833,   0.00824207,   0.0264887,   0.0305169},
        {// Batch1: 4 (input_sequence_size) * 16 (n_output)
         -0.013869,    0.0287268,   -0.00334693, 0.00733398,  -0.0287926,
         -0.0186926,   0.0193662,   -0.0115437,  0.00422612,  -0.0345232,
         0.00223253,   -0.00957321, 0.0210624,   0.013331,    0.0150954,
         0.02168,      -0.0141913,  0.0322082,   0.00227024,  0.0260507,
         -0.0188721,   -0.0296489,  0.0399134,   -0.0160509,  0.0116039,
         -0.0447318,   -0.0150515,  -0.0277406,  0.0316596,   0.0118233,
         0.0214762,    0.0293641,   -0.0204549,  0.0450315,   -0.00117378,
         0.0167673,    -0.0375007,  -0.0238314,  0.038784,    -0.0174034,
         0.0131743,    -0.0506589,  -0.0048447,  -0.0240239,  0.0325789,
         0.00790065,   0.0220157,   0.0333314,   -0.0264787,  0.0387855,
         -0.000764675, 0.0217599,   -0.037537,   -0.0335206,  0.0431679,
         -0.0211424,   0.010203,    -0.062785,   -0.00832363, -0.025181,
         0.0412031,    0.0118723,   0.0239643,   0.0394009}};
  }
};

TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
  // Model dimensions; the fixture's SetUp() data is sized for these values.
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 20;
  const int n_output = 16;

  // NOTE(review): despite the fixture name, both clip thresholds below are
  // 0.0. The golden outputs were presumably produced with these exact values,
  // so confirm against the reference kernel test before changing them.
  LSTMOpModel model(n_batch, n_input, n_cell, n_output,
                    /*use_cifg=*/false, /*use_peephole=*/true,
                    /*use_projection_weights=*/true,
                    /*use_projection_bias=*/false,
                    /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                    {
                        {n_batch, n_input},  // input tensor

                        {n_cell, n_input},  // input_to_input_weight tensor
                        {n_cell, n_input},  // input_to_forget_weight tensor
                        {n_cell, n_input},  // input_to_cell_weight tensor
                        {n_cell, n_input},  // input_to_output_weight tensor

                        {n_cell, n_output},  // recurrent_to_input_weight tensor
                        {n_cell, n_output},  // recurrent_to_forget_weight tensor
                        {n_cell, n_output},  // recurrent_to_cell_weight tensor
                        {n_cell, n_output},  // recurrent_to_output_weight tensor

                        {n_cell},  // cell_to_input_weight tensor
                        {n_cell},  // cell_to_forget_weight tensor
                        {n_cell},  // cell_to_output_weight tensor

                        {n_cell},  // input_gate_bias tensor
                        {n_cell},  // forget_gate_bias tensor
                        {n_cell},  // cell_bias tensor
                        {n_cell},  // output_gate_bias tensor

                        {n_output, n_cell},  // projection_weight tensor
                        {0},                 // projection_bias tensor

                        {n_batch, n_output},  // activation_state tensor
                        {n_batch, n_cell},    // cell_state tensor
                    },
                    /*weight_type=*/TensorType_FLOAT32);

  // Each setter fills a distinct tensor, so the parameters are grouped by the
  // gate they feed rather than by weight kind.

  // Input gate.
  model.SetInputToInputWeights(input_to_input_weights_);
  model.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  model.SetCellToInputWeights(cell_to_input_weights_);
  model.SetInputGateBias(input_gate_bias_);

  // Forget gate.
  model.SetInputToForgetWeights(input_to_forget_weights_);
  model.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  model.SetCellToForgetWeights(cell_to_forget_weights_);
  model.SetForgetGateBias(forget_gate_bias_);

  // Cell gate.
  model.SetInputToCellWeights(input_to_cell_weights_);
  model.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  model.SetCellBias(cell_gate_bias_);

  // Output gate.
  model.SetInputToOutputWeights(input_to_output_weights_);
  model.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  model.SetCellToOutputWeights(cell_to_output_weights_);
  model.SetOutputGateBias(output_gate_bias_);

  // Projection: weights only, because use_projection_bias is false above.
  model.SetProjectionWeights(projection_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &model);
}

class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
    : public BaseLstmTest {
  // Fixture data for a small layer-normalized LSTM sized for n_cell = 4,
  // n_input = 5 and n_output = 3.
  // Weight matrices are written one row per line:
  //   input_to_*      : {n_cell, n_input}  -> 4 rows of 5
  //   recurrent_to_*  : {n_cell, n_output} -> 4 rows of 3
  //   projection      : {n_output, n_cell} -> 3 rows of 4
  // Members are grouped by the gate they belong to.
  void SetUp() override {
    // ---- Input gate ----
    input_to_input_weights_ = {
        0.5,  0.6,  0.7,  -0.8, -0.9,  // cell 0
        0.1,  0.2,  0.3,  -0.4, 0.5,   // cell 1
        -0.8, 0.7,  -0.6, 0.5,  -0.4,  // cell 2
        -0.5, -0.4, -0.3, -0.2, -0.1,  // cell 3
    };
    recurrent_to_input_weights_ = {
        -0.2, -0.3, 0.4,   // cell 0
        0.1,  -0.5, 0.9,   // cell 1
        -0.2, -0.3, -0.7,  // cell 2
        0.05, -0.2, -0.6,  // cell 3
    };
    cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
    input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
    input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};

    // ---- Forget gate ----
    input_to_forget_weights_ = {
        -0.6, -0.1, 0.3,  0.2,  0.9,   // cell 0
        -0.5, -0.2, -0.4, 0.3,  -0.8,  // cell 1
        -0.4, 0.3,  -0.5, -0.4, -0.6,  // cell 2
        0.3,  -0.4, -0.6, -0.5, -0.5,  // cell 3
    };
    recurrent_to_forget_weights_ = {
        -0.5, -0.3, -0.5,  // cell 0
        -0.2, 0.6,  0.4,   // cell 1
        0.9,  0.3,  -0.1,  // cell 2
        0.2,  0.5,  0.2,   // cell 3
    };
    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};

    // ---- Cell gate ----
    input_to_cell_weights_ = {
        -0.4, -0.3, -0.2, -0.1, -0.5,  // cell 0
        0.5,  -0.2, -0.3, -0.2, -0.6,  // cell 1
        0.6,  -0.1, -0.4, -0.3, -0.7,  // cell 2
        0.7,  -0.9, -0.5, 0.8,  0.6,   // cell 3
    };
    recurrent_to_cell_weights_ = {
        -0.3, 0.2,  0.1,    // cell 0
        -0.3, 0.8,  -0.08,  // cell 1
        -0.2, 0.3,  0.8,    // cell 2
        -0.6, -0.1, 0.2,    // cell 3
    };
    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};

    // ---- Output gate ----
    input_to_output_weights_ = {
        -0.8, -0.4, -0.2, -0.9, -0.1,  // cell 0
        -0.7, 0.3,  -0.3, -0.8, -0.2,  // cell 1
        0.6,  -0.2, 0.4,  -0.7, -0.3,  // cell 2
        -0.5, 0.1,  0.5,  -0.6, -0.4,  // cell 3
    };
    recurrent_to_output_weights_ = {
        0.3,  -0.1, 0.1,   // cell 0
        -0.2, -0.5, -0.7,  // cell 1
        -0.2, -0.6, -0.1,  // cell 2
        -0.4, -0.7, -0.2,  // cell 3
    };
    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};

    // ---- Projection ----
    projection_weights_ = {
        -0.1, 0.2, 0.01, -0.2,  // output 0
        0.1,  0.5, 0.3,  0.08,  // output 1
        0.07, 0.2, -0.4, 0.2,   // output 2
    };

    // ---- Input sequences ----
    lstm_input_ = {
        {// Batch0: 3 (input_sequence_size) * 5 (n_input)
         0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
         0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
         0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

        {// Batch1: 3 (input_sequence_size) * 5 (n_input)
         0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
         0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
         0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };
  }
};

// Runs the layer-normalized, peephole, projection LSTM (no CIFG) through the
// NNAPI delegate on the fixture's two-batch, three-step input. Each step's
// output is compared against golden values.
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
       LayerNormLstmBlackBoxTest) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  // Clip thresholds forwarded to LSTMOpModel's cell_clip / proj_clip
  // parameters. Both are 0.0 for this "NoClipping" fixture.
  const float cell_clip = 0.0;
  const float proj_clip = 0.0;

  LSTMOpModel layer_norm_lstm(
      n_batch, n_input, n_cell, n_output,
      /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip, proj_clip,
      {
          {n_batch, n_input},  // input tensor

          {n_cell, n_input},  // input_to_input_weight tensor
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {n_cell, n_output},  // recurrent_to_input_weight tensor
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {n_cell},  // cell_to_input_weight tensor
          {n_cell},  // cell_to_forget_weight tensor
          {n_cell},  // cell_to_output_weight tensor

          {n_cell},  // input_gate_bias tensor
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_bias tensor
          {n_cell},  // output_gate_bias tensor

          {n_output, n_cell},  // projection_weight tensor
          {0},                 // projection_bias tensor

          {n_batch, n_output},  // activation_state tensor
          {n_batch, n_cell},    // cell_state tensor

          {n_cell},  // input_layer_norm_coefficient tensor
          {n_cell},  // forget_layer_norm_coefficient tensor
          {n_cell},  // cell_layer_norm_coefficient tensor
          {n_cell},  // output_layer_norm_coefficient tensor
      },
      /*weight_type=*/TensorType_FLOAT32);

  layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);

  layer_norm_lstm.SetInputGateBias(input_gate_bias_);
  layer_norm_lstm.SetCellBias(cell_gate_bias_);
  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
  layer_norm_lstm.SetForgetLayerNormCoefficients(
      forget_layer_norm_coefficients_);
  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
  layer_norm_lstm.SetOutputLayerNormCoefficients(
      output_layer_norm_coefficients_);

  layer_norm_lstm.SetProjectionWeights(projection_weights_);

  // Expected output for every time step of each batch, not only the last one.
  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
      {
          // Batch0: 3 (input_sequence_size) * 3 (n_output)
          0.0244077, 0.128027, -0.00170918,  // seq 0
          0.0137642, 0.140751, 0.0395835,    // seq 1
          -0.00459231, 0.155278, 0.0837377,  // seq 2
      },
      {
          // Batch1: 3 (input_sequence_size) * 3 (n_output)
          -0.00692428, 0.0848741, 0.063445,  // seq 0
          -0.00403912, 0.139963, 0.072681,   // seq 1
          0.00752706, 0.161903, 0.0561371,   // seq 2
      }};

  VerifyGoldens(lstm_input_, layer_norm_lstm_golden_output, &layer_norm_lstm);
}

// LSTM fixture for the CIFG + peephole + projection layer-norm configuration
// with clipping disabled. SetUp() only assigns the tensors that configuration
// uses: no input-gate weights, input gate bias, cell-to-input peephole or input
// layer-norm coefficients are assigned here.
class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
  void SetUp() override {
    // Input weights, shaped {n_cell, n_input} = {4, 5}: one row per cell.
    input_to_forget_weights_ = {
        -0.6, -0.1, 0.3,  0.2,  0.9,   // cell 0
        -0.5, -0.2, -0.4, 0.3,  -0.8,  // cell 1
        -0.4, 0.3,  -0.5, -0.4, -0.6,  // cell 2
        0.3,  -0.4, -0.6, -0.5, -0.5,  // cell 3
    };
    input_to_cell_weights_ = {
        -0.4, -0.3, -0.2, -0.1, -0.5,  // cell 0
        0.5,  -0.2, -0.3, -0.2, -0.6,  // cell 1
        0.6,  -0.1, -0.4, -0.3, -0.7,  // cell 2
        0.7,  -0.9, -0.5, 0.8,  0.6,   // cell 3
    };
    input_to_output_weights_ = {
        -0.8, -0.4, -0.2, -0.9, -0.1,  // cell 0
        -0.7, 0.3,  -0.3, -0.8, -0.2,  // cell 1
        0.6,  -0.2, 0.4,  -0.7, -0.3,  // cell 2
        -0.5, 0.1,  0.5,  -0.6, -0.4,  // cell 3
    };

    // Recurrent weights, shaped {n_cell, n_output} = {4, 3}.
    recurrent_to_cell_weights_ = {
        -0.3, 0.2,  0.1,    // cell 0
        -0.3, 0.8,  -0.08,  // cell 1
        -0.2, 0.3,  0.8,    // cell 2
        -0.6, -0.1, 0.2,    // cell 3
    };
    recurrent_to_forget_weights_ = {
        -0.5, -0.3, -0.5,  // cell 0
        -0.2, 0.6,  0.4,   // cell 1
        0.9,  0.3,  -0.1,  // cell 2
        0.2,  0.5,  0.2,   // cell 3
    };
    recurrent_to_output_weights_ = {
        0.3,  -0.1, 0.1,   // cell 0
        -0.2, -0.5, -0.7,  // cell 1
        -0.2, -0.6, -0.1,  // cell 2
        -0.4, -0.7, -0.2,  // cell 3
    };

    // Per-cell gate biases (there is no input gate bias in this config).
    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    // Peephole (cell-to-gate) weights.
    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    // Layer-norm coefficients, one per cell.
    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};

    // Projection weights, shaped {n_output, n_cell} = {3, 4}.
    projection_weights_ = {
        -0.1, 0.2, 0.01, -0.2,  // output 0
        0.1,  0.5, 0.3,  0.08,  // output 1
        0.07, 0.2, -0.4, 0.2,   // output 2
    };

    // Two batches, each 3 (input_sequence_size) steps of 5 (n_input) values.
    lstm_input_ = {
        {
            0.7, 0.8, 0.1, 0.2, 0.3,  // seq 0
            0.8, 0.1, 0.2, 0.4, 0.5,  // seq 1
            0.2, 0.7, 0.7, 0.1, 0.7,  // seq 2
        },
        {
            0.3, 0.2, 0.9, 0.8, 0.1,  // seq 0
            0.1, 0.5, 0.2, 0.4, 0.2,  // seq 1
            0.6, 0.9, 0.2, 0.5, 0.7,  // seq 2
        },
    };
  }
};

// Builds a CIFG + peephole + projection layer-norm LSTM through the NNAPI
// delegate, loads the fixture's weights and checks the per-step outputs of
// lstm_input_ against precomputed goldens via VerifyGoldens (defined in the
// fixture base, not shown here).
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
       LayerNormLstmBlackBoxTest) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  // Clip thresholds of 0 for the cell state and the projection: this fixture
  // is the "NoClipping" configuration.
  const float cell_clip = 0.0;
  const float proj_clip = 0.0;

  LSTMOpModel layer_norm_lstm(
      n_batch, n_input, n_cell, n_output,
      /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip, proj_clip,
      {
          {n_batch, n_input},  // input tensor

          {0, 0},             // input_to_input_weight tensor (unused: CIFG)
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {0, 0},              // recurrent_to_input_weight tensor (unused)
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {0},       // cell_to_input_weight tensor (unused)
          {n_cell},  // cell_to_forget_weight tensor
          {n_cell},  // cell_to_output_weight tensor

          {0},       // input_gate_bias tensor (unused)
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_bias tensor
          {n_cell},  // output_gate_bias tensor

          {n_output, n_cell},  // projection_weight tensor
          {0},                 // projection_bias tensor (disabled)

          {n_batch, n_output},  // activation_state tensor
          {n_batch, n_cell},    // cell_state tensor

          {0},       // input_layer_norm_coefficient tensor (unused)
          {n_cell},  // forget_layer_norm_coefficient tensor
          {n_cell},  // cell_layer_norm_coefficient tensor
          {n_cell},  // output_layer_norm_coefficient tensor
      },
      /*weight_type=*/TensorType_FLOAT32);

  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);

  layer_norm_lstm.SetCellBias(cell_gate_bias_);
  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  layer_norm_lstm.SetForgetLayerNormCoefficients(
      forget_layer_norm_coefficients_);
  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
  layer_norm_lstm.SetOutputLayerNormCoefficients(
      output_layer_norm_coefficients_);

  layer_norm_lstm.SetProjectionWeights(projection_weights_);

  // Verify the final output.
  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
      {
          // Batch0: 3 (input_sequence_size) * 3 (n_output)
          0.02129706, 0.140816242, 0.0112733059,     // seq 0
          0.0132302344, 0.152308047, 0.0346313119,   // seq 1
          -0.0123688057, 0.165790111, 0.0893077999,  // seq 2
      },
      {
          // Batch1: 3 (input_sequence_size) * 3 (n_output)
          -0.0226350538, 0.0916948169, 0.0769175813,  // seq 0
          -0.0269966982, 0.149707705, 0.094149217,    // seq 1
          -0.0103429332, 0.173016444, 0.0720508844,   // seq 2
      }};

  VerifyGoldens(lstm_input_, layer_norm_lstm_golden_output, &layer_norm_lstm);
}

// Common state and accessors for single-op reduction models (MEAN below).
// Subclasses register the input/axis/output tensors and build the interpreter.
class BaseReduceOpModel : public SingleOpModelWithNNAPI {
 public:
  // Writes the reduction axes into the axis tensor.
  void SetAxis(const std::vector<int>& data) { PopulateTensor(axis_, data); }

  // Writes `data` into the input tensor.
  template <class T>
  void SetInput(const std::vector<T>& data) {
    PopulateTensor(input_, data);
  }

  // Reads the output tensor back as elements of type T.
  template <class T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

  // Reads a uint8 output tensor and dequantizes it using the output tensor's
  // own scale and zero point.
  std::vector<float> GetDequantizedOutput() {
    const auto quantized = ExtractVector<uint8_t>(output_);
    const auto scale = GetScale(output_);
    const auto zero_point = GetZeroPoint(output_);
    return Dequantize<uint8_t>(quantized, scale, zero_point);
  }

  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  // Tensor index of the input.
  int Input() { return input_; }

 protected:
  int input_;
  int axis_;
  int output_;
};

// MEAN model for test cases where the reduction axis is a regular (non-const)
// input tensor that is populated at runtime via SetAxis().
class MeanOpDynamicModel : public BaseReduceOpModel {
 public:
  MeanOpDynamicModel(const TensorData& input, const TensorData& output,
                     const TensorData& axis, bool keep_dims) {
    // Registration order defines operand order: input first, then axis.
    input_ = AddInput(input);
    axis_ = AddInput(axis);
    output_ = AddOutput(output);
    const auto reducer_options = CreateReducerOptions(builder_, keep_dims);
    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
                 reducer_options.Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }
};

TEST(DynamicFloatMeanOpTest, NotKeepDims) {
  // The values 1..24 filling a {4, 3, 2} input.
  std::vector<float> data(24);
  for (size_t i = 0; i < data.size(); ++i) {
    data[i] = static_cast<float>(i + 1);
  }

  MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
                       {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
                       /*keep_dims=*/false);
  // Axes 1 and 0; -3 names axis 0 again for this rank-3 input, so only the
  // last dimension survives.
  m.SetAxis({1, 0, -3, -3});
  m.SetInput(data);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected_shape = {2};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
  const std::vector<float> expected_means = {12, 13};
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_means)));
}

// MEAN model for test cases where the reduction axis is baked into the graph
// as a constant INT32 tensor.
class MeanOpConstModel : public BaseReduceOpModel {
 public:
  MeanOpConstModel(const TensorData& input, const TensorData& output,
                   std::initializer_list<int> axis_shape,
                   std::initializer_list<int> axis, bool keep_dims) {
    input_ = AddInput(input);
    axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
    output_ = AddOutput(output);
    const auto reducer_options = CreateReducerOptions(builder_, keep_dims);
    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
                 reducer_options.Union());
    // Only the non-constant input needs an explicit shape here.
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }
};

// Tests for reduce_mean.
TEST(NNAPIDelegate, MeanFloatNotKeepDims) {
  // The values 1..24 filling a {4, 3, 2} input.
  std::vector<float> data;
  data.reserve(24);
  for (int value = 1; value <= 24; ++value) {
    data.push_back(static_cast<float>(value));
  }

  MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
                     /*axis_shape=*/{4}, /*axis=*/{1, 0, -3, -3},
                     /*keep_dims=*/false);
  m.SetInput(data);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected_shape = {2};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
  const std::vector<float> expected_means = {12, 13};
  EXPECT_THAT(m.GetOutput<float>(), NnapiArrayFloatNear(expected_means));
}

TEST(NNAPIDelegate, MeanFloatKeepDims) {
  // The values 1..24 filling a {4, 3, 2} input.
  std::vector<float> data;
  data.reserve(24);
  for (int value = 1; value <= 24; ++value) {
    data.push_back(static_cast<float>(value));
  }

  // Reduce over axes 0 and 2, keeping them as size-1 dimensions.
  MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
                     /*axis_shape=*/{2}, /*axis=*/{0, 2}, /*keep_dims=*/true);
  m.SetInput(data);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<int> expected_shape = {1, 3, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
  const std::vector<float> expected_means = {10.5, 12.5, 14.5};
  EXPECT_THAT(m.GetOutput<float>(), NnapiArrayFloatNear(expected_means));
}

// EMBEDDING_LOOKUP model: an INT32 index tensor selects rows of a weight
// tensor (float by default); the output is read back as float.
class BaseEmbeddingLookupOpModel : public SingleOpModelWithNNAPI {
 public:
  BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
                             std::initializer_list<int> weight_shape,
                             TensorType weight_type = TensorType_FLOAT32) {
    // Operand order: indices, then weights.
    input_ = AddInput(TensorType_INT32);
    weight_ = AddInput(weight_type);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE,
                 /*builtin_options=*/0);
    BuildInterpreterWithNNAPI({index_shape, weight_shape});
  }

  // Writes the row indices to look up.
  void SetInput(std::initializer_list<int> data) {
    PopulateTensor(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input_;
  int weight_;
  int output_;
};

class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
 public:
  using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;

  void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
    TfLiteTensor* tensor = interpreter_->tensor(weight_);
    int rows = tensor->dims->data[0];
    int columns = tensor->dims->data[1];
    int features = tensor->dims->data[2];
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < columns; j++) {
        for (int k = 0; k < features; k++) {
          tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
        }
      }
    }
  }
};

TEST(NNAPIDelegate, EmbeddingLookupSimpleTest) {
  // Three indices into a {3, 2, 4} weight tensor whose entry [i][j][k] encodes
  // its own coordinates as i + j/10 + k/100.
  EmbeddingLookupOpModel m({3}, {3, 2, 4});
  m.SetInput({1, 0, 2});
  m.Set3DWeightMatrix([](int row, int col, int feature) {
    return row + col / 10.0f + feature / 100.0f;
  });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  // Rows come back in lookup order: 1, 0, 2.
  const std::vector<float> expected = {
      1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected));
}

class HashtableLookupOpModel : public SingleOpModelWithNNAPI {
 public:
  HashtableLookupOpModel(std::initializer_list<int> lookup_shape,
                         std::initializer_list<int> key_shape,
                         std::initializer_list<int> value_shape,
                         TensorType type) {
    lookup_ = AddInput(TensorType_INT32);
    key_ = AddInput(TensorType_INT32);
    value_ = AddInput(type);
    output_ = AddOutput(type);
    hit_ = AddOutput(TensorType_UINT8);
    SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({lookup_shape, key_shape, value_shape});
  }

  void SetLookup(std::initializer_list<int> data) {
    PopulateTensor<int>(lookup_, data);
  }

  void SetHashtableKey(std::initializer_list<int> data) {
    PopulateTensor<int>(key_, data);
  }

  void SetHashtableValue(const std::vector<string>& content) {
    PopulateStringTensor(value_, content);
  }

  void SetHashtableValue(const std::function<float(int)>& function) {
    TfLiteTensor* tensor = interpreter_->tensor(value_);
    int rows = tensor->dims->data[0];
    for (int i = 0; i < rows; i++) {
      tensor->data.f[i] = function(i);
    }
  }

  void SetHashtableValue(const std::function<float(int, int)>& function) {
    TfLiteTensor* tensor = interpreter_->tensor(value_);
    int rows = tensor->dims->data[0];
    int features = tensor->dims->data[1];
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < features; j++) {
        tensor->data.f[i * features + j] = function(i, j);
      }
    }
  }

  std::vector<string> GetStringOutput() {
    TfLiteTensor* output = interpreter_->tensor(output_);
    int num = GetStringCount(output);
    std::vector<string> result(num);
    for (int i = 0; i < num; i++) {
      auto ref = GetString(output, i);
      result[i] = string(ref.str, ref.len);
    }
    return result;
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<uint8_t> GetHit() { return ExtractVector<uint8_t>(hit_); }

 private:
  int lookup_;
  int key_;
  int value_;
  int output_;
  int hit_;
};

TEST(NNAPIDelegate, HashtableLookupTest2DInput) {
  // Three keys, each mapped to a 2-wide row holding row + column / 10.
  HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32);
  m.SetLookup({1234, -292, -11, 0});
  m.SetHashtableKey({-11, 0, 1234});
  m.SetHashtableValue([](int row, int col) { return row + col / 10.0f; });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {
      2.0, 2.1,  // 1234 -> 2-nd item
      0,   0,    // -292 -> not found
      0.0, 0.1,  // -11 -> 0-th item
      1.0, 1.1,  // 0 -> 1-st item
  };
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected_values));
  const std::vector<int> expected_hits = {1, 0, 1, 1};
  EXPECT_THAT(m.GetHit(), ElementsAreArray(expected_hits));
}

TEST(NNAPIDelegate, HashtableLookupTest1DInput) {
  // Three keys, each mapped to a scalar row * row / 10.
  HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32);
  m.SetLookup({1234, -292, -11, 0});
  m.SetHashtableKey({-11, 0, 1234});
  m.SetHashtableValue([](int row) { return row * row / 10.0f; });

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {
      0.4,  // 1234 -> 2-nd item
      0,    // -292 -> not found
      0.0,  // -11 -> 0-th item
      0.1,  // 0 -> 1-st item
  };
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected_values));
  const std::vector<int> expected_hits = {1, 0, 1, 1};
  EXPECT_THAT(m.GetHit(), ElementsAreArray(expected_hits));
}

// PRELU op model used directly by the float and quantized PRelu tests below.
// The output tensor copies the input's type, shape and min/max range, and
// input, alpha and output values are all converted using the input's type.
class PReluOpModel : public SingleOpModelWithNNAPI {
 public:
  PReluOpModel(const TensorData& input, const TensorData& alpha)
      : input_type_(input.type) {
    input_ = AddInput(input);
    alpha_ = AddInput(alpha);
    output_ = AddOutput({input.type, input.shape, input.min, input.max});
    SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE,
                 /*builtin_options=*/0);
    BuildInterpreterWithNNAPI({GetShape(input_), GetShape(alpha_)});
  }

  // Writes real-valued input data (quantized if the input type requires it).
  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  // Writes real-valued alpha data, using the input's type.
  void SetAlpha(std::initializer_list<float> data) {
    SetData(alpha_, input_type_, data);
  }

  // Returns the output as real values.
  std::vector<float> GetOutput() {
    std::vector<float> values;
    GetData(output_, input_type_, &values);
    return values;
  }

 protected:
  int input_;
  int alpha_;
  int output_;

  const TensorType input_type_;
};

// Float PRELU on a {1, 2, 2, 3} input with a per-channel alpha of shape
// {1, 1, 3}: positive values pass through, negatives are scaled by alpha.
TEST(NNAPIDelegate, PReluFloat) {
  PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
                 {TensorType_FLOAT32, {1, 1, 3}});

  m.SetInput({
      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
      -2.0f, -2.0f, -2.0f,  // Row 2, Column 2
  });
  m.SetAlpha({0.0f, 1.0f, 2.0f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({
                                 0.0f, 0.0f, 0.0f,    // Row 1, Column 1
                                 1.0f, 1.0f, 1.0f,    // Row 1, Column 2
                                 0.0f, -1.0f, -2.0f,  // Row 2, Column 1
                                 0.0f, -2.0f, -4.0f,  // Row 2, Column 2
                             }));
}

// Quantized (uint8) PRELU over the range [-1, 127/128]; results are compared
// with kQuantizedTolerance to absorb quantization error.
TEST(NNAPIDelegate, PReluQuantized) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  PReluOpModel m({TensorType_UINT8, {1, 2, 2, 3}, kMin, kMax},
                 {TensorType_UINT8, {1, 1, 3}, kMin, kMax});
  m.SetInput({
      0.0f, 0.0f, 0.0f,        // Row 1, Column 1
      0.5f, 0.5f, 0.5f,        // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,     // Row 2, Column 1
      -0.25f, -0.25f, -0.25f,  // Row 2, Column 2
  });
  m.SetAlpha({0.0f, 0.5f, -0.5f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     0.0f, 0.0f, 0.0f,       // Row 1, Column 1
                                     0.5f, 0.5f, 0.5f,       // Row 1, Column 2
                                     0.0f, -0.5f, 0.5f,      // Row 2, Column 1
                                     0.0f, -0.125f, 0.125f,  // Row 2, Column 2
                                 },
                                 kQuantizedTolerance)));
}

// PADV2 model for test cases where the paddings are a constant tensor.
// T1 is the element type of the input/output tensors (and of the scalar pad
// value in the first constructor).
template <typename T1>
class PadV2OpConstModel : public PadOpModel<T1> {
 public:
  // Pad value baked into the graph as a constant 1-element tensor.
  PadV2OpConstModel(const TensorData& input,
                    std::initializer_list<int> paddings_shape,
                    std::initializer_list<int> paddings, T1 constant_values,
                    const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ =
        this->AddConstInput(TensorType_INT32, paddings, paddings_shape);
    this->constant_values_ =
        this->AddConstInput(GetTensorType<T1>(), {constant_values}, {1});

    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape});
  }

  // Pad value supplied through a regular input tensor described by
  // `constant_values`; tests populate it before Invoke() (e.g. with
  // SetQuantizedPadValue).
  PadV2OpConstModel(const TensorData& input,
                    std::initializer_list<int> paddings_shape,
                    std::initializer_list<int> paddings,
                    const TensorData& constant_values,
                    const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ =
        this->AddConstInput(TensorType_INT32, paddings, paddings_shape);
    this->constant_values_ = this->AddInput(constant_values);

    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape});
  }
};

// PADV2 model for test cases where the paddings are a regular (non-const)
// INT32 input populated at runtime via SetPaddings(). RegularInputOutput is
// the element type of the input/output tensors.
template <typename RegularInputOutput>
class PadV2OpDynamicModel : public PadOpModel<RegularInputOutput> {
 public:
  // Pad value baked into the graph as a constant 1-element tensor.
  PadV2OpDynamicModel(const TensorData& input,
                      std::initializer_list<int> paddings_shape,
                      RegularInputOutput constant_values,
                      const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ = this->AddInput(TensorType_INT32);
    const TensorType pad_value_type = GetTensorType<RegularInputOutput>();
    this->constant_values_ =
        this->AddConstInput(pad_value_type, {constant_values}, {1});
    this->output_ = this->AddOutput(output);

    const auto pad_options = CreatePadV2Options(this->builder_);
    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       pad_options.Union());
    this->BuildInterpreterWithNNAPI({input.shape, paddings_shape});
  }

  // Pad value supplied through a regular input tensor described by
  // `constant_values`.
  PadV2OpDynamicModel(const TensorData& input,
                      std::initializer_list<int> paddings_shape,
                      const TensorData& constant_values,
                      const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ = this->AddInput(TensorType_INT32);
    this->constant_values_ = this->AddInput(constant_values);
    this->output_ = this->AddOutput(output);

    const auto pad_options = CreatePadV2Options(this->builder_);
    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       pad_options.Union());
    this->BuildInterpreterWithNNAPI({input.shape, paddings_shape});
  }
};

TEST(PadV2OpTest, SimpleConstTest) {
  // Paddings are (before, after) pairs per dimension:
  // {{0, 0}, {1, 1}, {1, 1}, {0, 0}} puts one ring of zeros around the 2x2
  // spatial plane.
  PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                             {0, 0, 1, 1, 1, 1, 0, 0}, 0.0,
                             {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0, 0, 0, 0,  //
                                       0, 1, 2, 0,  //
                                       0, 3, 4, 0,  //
                                       0, 0, 0, 0};
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected));
  const std::vector<int> expected_shape = {1, 4, 4, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(PadV2OpTest, SimpleConstFloat32ValuedTestUint8) {
  // Paddings {{0, 0}, {1, 1}, {1, 1}, {0, 0}} with a constant pad value of 5:
  // one ring of 5s around the 2x2 spatial plane.
  PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                             {0, 0, 1, 1, 1, 1, 0, 0}, /*constant_values=*/5,
                             {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {5, 5, 5, 5,  //
                                       5, 1, 2, 5,  //
                                       5, 3, 4, 5,  //
                                       5, 5, 5, 5};
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected));
  const std::vector<int> expected_shape = {1, 4, 4, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(PadV2OpTest, Simple4DConstFloat32ValuedTest) {
  // Padding is represented as four 2-D lists representing above padding and
  // below padding (i.e. {{0, 1}, {0, 0}, {0, 0}, {0, 1}}): one extra batch
  // and one extra channel, both filled with the pad value 5.
  PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 1, 2, 1}}, {4, 2},
                             {0, 1, 0, 0, 0, 0, 0, 1}, 5, {TensorType_FLOAT32});
  m.SetInput({3, 3});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({3, 5, 3, 5, 5, 5, 5, 5}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 2}));
}

TEST(PadV2OpTest, SimpleDynamicTest) {
  PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                               /*constant_values=*/0.0, {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4});
  // Runtime paddings {{0, 0}, {1, 1}, {1, 1}, {0, 0}}.
  m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0, 0, 0, 0,  //
                                       0, 1, 2, 0,  //
                                       0, 3, 4, 0,  //
                                       0, 0, 0, 0};
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected));
  const std::vector<int> expected_shape = {1, 4, 4, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(PadV2OpTest, SimpleDynamicValuedTest) {
  PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                               /*constant_values=*/5, {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4});
  // Runtime paddings {{0, 0}, {1, 1}, {1, 1}, {0, 0}}.
  m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {5, 5, 5, 5,  //
                                       5, 1, 2, 5,  //
                                       5, 3, 4, 5,  //
                                       5, 5, 5, 5};
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected));
  const std::vector<int> expected_shape = {1, 4, 4, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(PadV2OpTest, AdvancedConstTest) {
  // Paddings {{0, 0}, {0, 2}, {1, 3}, {0, 0}}: a 2x3 plane grows to 4x7.
  PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                             {0, 0, 0, 2, 1, 3, 0, 0}, /*constant_values=*/0,
                             {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4, 5, 6});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0, 1, 2, 3, 0, 0, 0,  //
                                       0, 4, 5, 6, 0, 0, 0,  //
                                       0, 0, 0, 0, 0, 0, 0,  //
                                       0, 0, 0, 0, 0, 0, 0};
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected));
  const std::vector<int> expected_shape = {1, 4, 7, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(PadV2OpTest, AdvancedDynamicTest) {
  PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                               /*constant_values=*/0, {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4, 5, 6});
  // Runtime paddings {{0, 0}, {0, 2}, {1, 3}, {0, 0}}: 2x3 grows to 4x7.
  m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0, 1, 2, 3, 0, 0, 0,  //
                                       0, 4, 5, 6, 0, 0, 0,  //
                                       0, 0, 0, 0, 0, 0, 0,  //
                                       0, 0, 0, 0, 0, 0, 0};
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear(expected));
  const std::vector<int> expected_shape = {1, 4, 7, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

// Element-wise float matchers whose tolerance is one step of an 8-bit
// (255-interval) quantization grid spanning [min, max].
std::vector<testing::Matcher<float>> DequantizedArrayNear(
    const std::vector<float>& values, const float min, const float max) {
  const float one_step = static_cast<float>((max - min) / 255.0);
  return ArrayFloatNear(values, one_step);
}

// Quantized PADV2 with constant paddings {{0, 0}, {1, 1}, {1, 1}, {0, 0}} and
// the pad value (0) fed through a quantized input tensor; all tensors share
// the range [-1, 1].
template <typename integer_type, TensorType tensor_dtype>
void SimpleConstTestV2() {
  PadV2OpConstModel<integer_type> m(
      {tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
      {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
  m.template SetQuantizedPadValue<integer_type>(0);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0, 0,    0,   0,  //
                                       0, -0.8, 0.2, 0,  //
                                       0, 0.9,  0.7, 0,  //
                                       0, 0,    0,   0};
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(expected, -1.0, 1.0)));
  const std::vector<int> expected_shape = {1, 4, 4, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(QuantizedPadV2OpTest, UInt8SimpleConstTest) {
  SimpleConstTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleConstTest) {
  SimpleConstTestV2<int8_t, TensorType_INT8>();
}

// Quantized PADV2 with runtime paddings {{0, 0}, {1, 1}, {1, 1}, {0, 0}} and
// a quantized pad value of 0; all tensors share the range [-1, 1].
template <typename integer_type, TensorType tensor_dtype>
void SimpleDynamicTestV2() {
  PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0},
                                      {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
                                      {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
  m.template SetQuantizedPadValue<integer_type>(0);
  m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0, 0,    0,   0,  //
                                       0, -0.8, 0.2, 0,  //
                                       0, 0.9,  0.7, 0,  //
                                       0, 0,    0,   0};
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(expected, -1.0, 1.0)));
  const std::vector<int> expected_shape = {1, 4, 4, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(QuantizedPadV2OpTest, UInt8SimpleDynamicTest) {
  SimpleDynamicTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleDynamicTest) {
  SimpleDynamicTestV2<int8_t, TensorType_INT8>();
}

// Quantized PADV2 with constant paddings {{0, 0}, {0, 2}, {1, 3}, {0, 0}}
// (2x3 plane grows to 4x7) and a quantized pad value of 0.
template <typename integer_type, TensorType tensor_dtype>
void AdvancedConstTestV2() {
  PadV2OpConstModel<integer_type> m(
      {tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0},
      {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
  m.template SetQuantizedPadValue<integer_type>(0);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0, -0.8, 0.2, 0.9,  0, 0, 0,  //
                                       0, 0.7,  0.1, -0.3, 0, 0, 0,  //
                                       0, 0,    0,   0,    0, 0, 0,  //
                                       0, 0,    0,   0,    0, 0, 0};
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(expected, -1.0, 1.0)));
  const std::vector<int> expected_shape = {1, 4, 7, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(QuantizedPadV2OpTest, UInt8AdvancedConstTest) {
  AdvancedConstTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedConstTest) {
  AdvancedConstTestV2<int8_t, TensorType_INT8>();
}

// Quantized PADV2 with runtime paddings {{0, 0}, {0, 2}, {1, 3}, {0, 0}}
// (2x3 plane grows to 4x7) and a quantized pad value of 0.
template <typename integer_type, TensorType tensor_dtype>
void AdvancedDynamicTestV2() {
  PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0},
                                      {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
                                      {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
  m.template SetQuantizedPadValue<integer_type>(0);
  m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {0, -0.8, 0.2, 0.9,  0, 0, 0,  //
                                       0, 0.7,  0.1, -0.3, 0, 0, 0,  //
                                       0, 0,    0,   0,    0, 0, 0,  //
                                       0, 0,    0,   0,    0, 0, 0};
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(expected, -1.0, 1.0)));
  const std::vector<int> expected_shape = {1, 4, 7, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(QuantizedPadV2OpTest, UInt8AdvancedDynamicTest) {
  AdvancedDynamicTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedDynamicTest) {
  AdvancedDynamicTestV2<int8_t, TensorType_INT8>();
}

// Quantized PADV2 with constant paddings {{0, 0}, {1, 1}, {1, 1}, {0, 0}} and
// a non-zero quantized pad value of -0.5.
template <typename integer_type, TensorType tensor_dtype>
void SimpleConstValuedTest() {
  PadV2OpConstModel<integer_type> m(
      {tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
      {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
  m.template SetQuantizedPadValue<integer_type>(-0.5);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {-0.5, -0.5, -0.5, -0.5,  //
                                       -0.5, -0.8, 0.2,  -0.5,  //
                                       -0.5, 0.9,  0.7,  -0.5,  //
                                       -0.5, -0.5, -0.5, -0.5};
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(expected, -1.0, 1.0)));
  const std::vector<int> expected_shape = {1, 4, 4, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(QuantizedPadV2OpTest, UInt8SimpleConstValuedTest) {
  SimpleConstValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleConstValuedTest) {
  SimpleConstValuedTest<int8_t, TensorType_INT8>();
}

// Quantized PADV2 with runtime paddings {{0, 0}, {1, 1}, {1, 1}, {0, 0}} and
// a non-zero quantized pad value of -0.5.
template <typename integer_type, TensorType tensor_dtype>
void SimpleDynamicValuedTest() {
  PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0},
                                      {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
                                      {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
  m.template SetQuantizedPadValue<integer_type>(-0.5);
  m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {-0.5, -0.5, -0.5, -0.5,  //
                                       -0.5, -0.8, 0.2,  -0.5,  //
                                       -0.5, 0.9,  0.7,  -0.5,  //
                                       -0.5, -0.5, -0.5, -0.5};
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(expected, -1.0, 1.0)));
  const std::vector<int> expected_shape = {1, 4, 4, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(QuantizedPadV2OpTest, UInt8SimpleDynamicValuedTest) {
  SimpleDynamicValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleDynamicValuedTest) {
  SimpleDynamicValuedTest<int8_t, TensorType_INT8>();
}

// Quantized PADV2 with constant paddings {{0, 0}, {0, 2}, {1, 3}, {0, 0}}
// (2x3 plane grows to 4x7) and a quantized pad value of -0.5.
template <typename integer_type, TensorType tensor_dtype>
void AdvancedConstValuedTest() {
  PadV2OpConstModel<integer_type> m(
      {tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0},
      {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
  m.template SetQuantizedPadValue<integer_type>(-0.5);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      -0.5, -0.8, 0.2,  0.9,  -0.5, -0.5, -0.5,  //
      -0.5, 0.7,  0.1,  -0.3, -0.5, -0.5, -0.5,  //
      -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,  //
      -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5};
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(expected, -1.0, 1.0)));
  const std::vector<int> expected_shape = {1, 4, 7, 1};
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(expected_shape));
}

TEST(QuantizedPadV2OpTest, UInt8AdvancedConstValuedTest) {
  AdvancedConstValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedConstValuedTest) {
  AdvancedConstValuedTest<int8_t, TensorType_INT8>();
}

template <typename integer_type, TensorType tensor_dtype>
void AdvancedDynamicValuedTest() {
  PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0},
                                      {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
                                      {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
  m.template SetQuantizedPadValue<integer_type>(-0.5);
  m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(
                  {-0.5, -0.8, 0.2,  0.9,  -0.5, -0.5, -0.5, -0.5, 0.7,  0.1,
                   -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
                   -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5},
                  -1.0, 1.0)));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}

TEST(QuantizedPadV2OpTest, UInt8AdvancedDynamicValuedTest) {
  AdvancedDynamicValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedDynamicValuedTest) {
  AdvancedDynamicValuedTest<int8_t, TensorType_INT8>();
}

// A base class of Leaky ReLU op model. It provides the constructor for
// FloatLeakyReluOpModel and QuantizedLeakyReluOpModel.
class LeakyReluOpModel : public SingleOpModelWithNNAPI {
 public:
  LeakyReluOpModel(const TensorData& input, const float alpha)
      : input_type_(input.type) {
    input_ = AddInput(input);
    output_ = AddOutput({input.type, input.shape, input.min, input.max});

    SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
                 CreateLeakyReluOptions(builder_, alpha).Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  std::vector<float> GetOutput() {
    std::vector<float> output;
    GetData(output_, input_type_, &output);
    return output;
  }

 protected:
  int input_;
  int output_;

  const TensorType input_type_;
};

// Float LEAKY_RELU: non-negative inputs pass through unchanged and negative
// inputs are multiplied by alpha (0.5 here).
TEST(NNAPIDelegate, LeakyReluFloat) {
  constexpr float kAlpha = 0.5f;
  LeakyReluOpModel model({TensorType_FLOAT32, {2, 3}}, kAlpha);

  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     0.0f, 1.0f, 3.0f,    // Row 1
                                     1.0f, -0.5f, -1.0f,  // Row 2
                                 }));
}

// Quantized (uint8) LEAKY_RELU over the range [-8, 8 * 127/128]. It expects
// the same results as the float test, within kQuantizedTolerance.
TEST(NNAPIDelegate, LeakyReluQuantized) {
  constexpr float kMin = -1;
  constexpr float kMax = 127.f / 128.f;
  constexpr float kAlpha = 0.5f;
  LeakyReluOpModel model({TensorType_UINT8, {2, 3}, 8 * kMin, 8 * kMax},
                         kAlpha);

  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const auto expected = ArrayFloatNear(
      {
          0.0f, 1.0f, 3.0f,    // Row 1
          1.0f, -0.5f, -1.0f,  // Row 2
      },
      kQuantizedTolerance);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
}  // namespace

// Forward declaration of the builtin FLOOR kernel registration. It is passed
// to SetCustomOp so that the test custom op runs as a float floor on the
// TFLite side.
namespace ops::builtin {
TfLiteRegistration* Register_FLOOR();
}  // namespace ops::builtin

namespace {

std::vector<uint32_t> GetNNAPIDimensions(const TfLiteTensor* tensor) {
  std::vector<uint32_t> dimensions;
  dimensions.reserve(tensor->dims->size);
  if (tensor->dims_signature != nullptr &&
      tensor->dims_signature->size == tensor->dims->size) {
    for (auto d : TfLiteIntArrayView(tensor->dims_signature)) {
      uint32_t nnapi_dim = (d == -1) ? 0 : static_cast<uint32_t>(d);
      dimensions.push_back(nnapi_dim);
    }
  } else {
    dimensions.assign(tensor->dims->data,
                      tensor->dims->data + tensor->dims->size);
  }
  return dimensions;
}

// Custom-op name handled by NnapiTestVendorPlugin. On the TFLite side the op
// is backed by the builtin FLOOR kernel, and the plugin lowers it to
// ANEURALNETWORKS_FLOOR, so "nnapi-custom-op" is just a float32 floor.
static const char kTestCustomOp[] = "nnapi-custom-op";

// Minimal NNAPI vendor plugin used by the custom-op tests. It:
//  - accepts `kTestCustomOp` nodes with exactly one float32 input and one
//    float32 output (DoValidateNode);
//  - lowers each accepted node to a single ANEURALNETWORKS_FLOOR operation
//    (DoMapNode);
//  - accepts and ignores compilation/execution hints.
class NnapiTestVendorPlugin : public NnapiDelegateVendorPlugin {
 public:
  NnapiTestVendorPlugin() {
    ValidateNode = DoValidateNode;
    MapNode = DoMapNode;
    ConfigureCompilationHints = DoConfigureCompilationHints;
    ConfigureExecutionHints = DoConfigureExecutionHints;
  }

  // Returns true iff `node` is a `kTestCustomOp` node with one float32 input
  // tensor and one float32 output tensor.
  static bool DoValidateNode(const TfLiteContext* context,
                             const TfLiteRegistration* registration,
                             const TfLiteNode* node) {
    // Builtin registrations may carry no custom name, and passing nullptr to
    // strcmp is undefined behavior, so reject those first.
    if (registration->custom_name == nullptr ||
        strcmp(kTestCustomOp, registration->custom_name) != 0) {
      return false;
    }
    if (node->inputs->size != 1 || node->outputs->size != 1) {
      return false;
    }
    return context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 &&
           context->tensors[node->outputs->data[0]].type == kTfLiteFloat32;
  }

  // Appends the NNAPI operand index of TFLite tensor `tensor_index` to
  // `indices`. If the tensor has no NNAPI operand yet, a TENSOR_FLOAT32
  // operand is added to `model`, and its data is attached when the tensor is a
  // kTfLiteMmapRo constant.
  //
  // Returns kTfLiteError (after recording a gtest failure) if an NNAPI call
  // fails, instead of silently continuing with a broken model.
  static TfLiteStatus AddFloat32Tensor(const TfLiteContext* context,
                                       int tensor_index,
                                       NnapiMappingUtilCInterface* mapping,
                                       std::vector<uint32_t>* indices,
                                       ANeuralNetworksModel* model) {
    int ann_tensor_index = mapping->TfLiteIndexToNnIndex(mapping, tensor_index);
    if (ann_tensor_index != -1) {
      // The tensor already has an NNAPI operand; reuse it.
      indices->push_back(ann_tensor_index);
      return kTfLiteOk;
    }
    // Allocate a new NNAPI index for the operand added below (presumably this
    // must stay in step with the order of addOperand calls - see mapping util).
    ann_tensor_index = mapping->AddNewNnTensorIndex(mapping, tensor_index);
    const TfLiteTensor* tensor = &context->tensors[tensor_index];
    auto dimensions = GetNNAPIDimensions(tensor);
    ANeuralNetworksOperandType operand_type{
        .type = ANEURALNETWORKS_TENSOR_FLOAT32,
        .dimensionCount = static_cast<uint32_t>(dimensions.size()),
        .dimensions = dimensions.data(),
        .scale = 0.0f,
        .zeroPoint = 0,
    };
    TF_LITE_ENSURE_STATUS(CheckNnApiResult(
        NnApiImplementation()->ANeuralNetworksModel_addOperand(
            model, &operand_type)));
    if (tensor->allocation_type == kTfLiteMmapRo) {
      TF_LITE_ENSURE_STATUS(CheckNnApiResult(
          NnApiImplementation()->ANeuralNetworksModel_setOperandValue(
              model, ann_tensor_index, tensor->data.data, tensor->bytes)));
    }
    indices->push_back(ann_tensor_index);
    return kTfLiteOk;
  }

  // Lowers one validated `kTestCustomOp` node to ANEURALNETWORKS_FLOOR in
  // `model` and records the NNAPI -> TFLite op mapping for `node_index`.
  // Stops at, and returns kTfLiteError for, the first failing NNAPI call.
  static TfLiteStatus DoMapNode(TfLiteContext* context, const TfLiteNode* node,
                                int node_index,
                                NnapiMappingUtilCInterface* mapping,
                                ANeuralNetworksModel* model) {
    std::vector<uint32_t> input_indices;
    std::vector<uint32_t> output_indices;
    TF_LITE_ENSURE_STATUS(AddFloat32Tensors(context, node->inputs, mapping,
                                            &input_indices, model));
    TF_LITE_ENSURE_STATUS(AddFloat32Tensors(context, node->outputs, mapping,
                                            &output_indices, model));
    TF_LITE_ENSURE_STATUS(CheckNnApiResult(
        NnApiImplementation()->ANeuralNetworksModel_addOperation(
            model, ANEURALNETWORKS_FLOOR,
            static_cast<uint32_t>(input_indices.size()), input_indices.data(),
            static_cast<uint32_t>(output_indices.size()),
            output_indices.data())));
    mapping->AddNnapiToTfliteOpMapping(mapping, node_index);
    return kTfLiteOk;
  }

  // Hints are not used by this test plugin; always succeeds.
  static TfLiteStatus DoConfigureCompilationHints(
      const char* compilation_hints, ANeuralNetworksCompilation* compilation) {
    return kTfLiteOk;
  }

  // Hints are not used by this test plugin; always succeeds.
  static TfLiteStatus DoConfigureExecutionHints(
      const char* execution_hints, ANeuralNetworksExecution* execution) {
    return kTfLiteOk;
  }

 private:
  // Records a gtest failure when `nnapi_result` is not
  // ANEURALNETWORKS_NO_ERROR and maps it to a TfLiteStatus.
  static TfLiteStatus CheckNnApiResult(int nnapi_result) {
    EXPECT_EQ(nnapi_result, ANEURALNETWORKS_NO_ERROR);
    return nnapi_result == ANEURALNETWORKS_NO_ERROR ? kTfLiteOk : kTfLiteError;
  }

  // Calls AddFloat32Tensor for every TFLite tensor index in `tensors`, in
  // order, stopping at the first failure.
  static TfLiteStatus AddFloat32Tensors(const TfLiteContext* context,
                                        const TfLiteIntArray* tensors,
                                        NnapiMappingUtilCInterface* mapping,
                                        std::vector<uint32_t>* indices,
                                        ANeuralNetworksModel* model) {
    for (const int tensor_index : TfLiteIntArrayView(tensors)) {
      TF_LITE_ENSURE_STATUS(
          AddFloat32Tensor(context, tensor_index, mapping, indices, model));
    }
    return kTfLiteOk;
  }
};

// NNAPI-delegated model with a single `kTestCustomOp` node, whose TFLite
// implementation is the builtin FLOOR kernel. Passing `apply_delegate = false`
// lets a test adjust tensors first (e.g. SetTensorMaxSize) and then call
// ApplyNNAPIDelegate() itself.
class CustomFloorOpModel : public SingleOpModelWithNNAPI {
 public:
  CustomFloorOpModel(const StatefulNnApiDelegate::Options& options,
                     const TensorData& input, const TensorData& output,
                     bool allow_fp32_relax_to_fp16 = false,
                     bool apply_delegate = true)
      : SingleOpModelWithNNAPI(options) {
    input_ = AddInput(input);
    output_ = AddOutput(output);
    SetCustomOp(kTestCustomOp, {}, tflite::ops::builtin::Register_FLOOR);
    BuildInterpreterWithNNAPI({GetShape(input_)}, allow_fp32_relax_to_fp16,
                              apply_delegate);
  }

  // Tensor index of the model's single input.
  int input() { return input_; }
  // Tensor index of the model's single output.
  int output() { return output_; }

  // Copies the output tensor's contents out as floats.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input_;
  int output_;
};

// Runs the custom floor op through the test vendor plugin on the NNAPI
// reference device with a static 1x2x2x1 float shape.
TEST(NNAPIDelegate, CustomFloorVendorExtension) {
  // `options` only stores a raw pointer to the plugin, so the plugin is
  // declared first to outlive the model.
  NnapiTestVendorPlugin vendor_plugin;
  StatefulNnApiDelegate::Options options;
  options.accelerator_name = "nnapi-reference";
  options.vendor_plugin = &vendor_plugin;

  const TensorData io_spec{TensorType_FLOAT32, {1, 2, 2, 1}};
  CustomFloorOpModel model(options, io_spec, io_spec);
  model.PopulateTensor<float>(model.input(), {0, 0.2, 1.7, 2.8});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0.0, 0.0, 1.0, 2.0}));
}

// Runs the custom floor op through the test vendor plugin with a dynamic
// batch dimension, trying batch sizes 2 (the configured maximum) and 1.
TEST(NNAPIDelegate, CustomFloorVendorExtensionDynamic) {
  // Skip the test until b/243704946 is fixed.
  GTEST_SKIP();
  // GTEST_SKIP() returns from the test body, so everything below is currently
  // unreachable; it is kept so the test can be re-enabled by dropping the skip.

  // Models with dynamic dimensions and a vendor plugin are not supported
  // before NNAPI 1.2 (API level 29).
  if (NnApiImplementation()->android_sdk_version <
      delegate::nnapi::kMinSdkVersionForNNAPI12) {
    GTEST_SKIP();
  }

  // `options` only stores a raw pointer to the plugin, so the plugin is
  // declared first to outlive the model.
  NnapiTestVendorPlugin vendor_plugin;
  StatefulNnApiDelegate::Options options;
  options.accelerator_name = "nnapi-reference";
  options.vendor_plugin = &vendor_plugin;
  options.allow_dynamic_dimensions = true;

  // Input and output share this spec: a dynamic (-1) batch in the signature.
  const TensorData tensor_data{TensorType_FLOAT32,
                               /*shape=*/{1, 2, 2, 1},
                               /*min=*/0.0f,
                               /*max=*/0.0f,
                               /*scale=*/0.0f,
                               /*zero_point=*/0,
                               /*per_channel_quantization=*/false,
                               /*per_channel_quantization_scales=*/{},
                               /*per_channel_quantization_offsets=*/{},
                               /*channel_index=*/0,
                               /*traversal_order=*/{},
                               /*format=*/{},
                               /*block_size=*/{},
                               /*block_map=*/{},
                               /*shape_signature=*/{-1, 2, 2, 1}};
  constexpr size_t kMaxBatchSize = 2;
  constexpr size_t kTensorMaxSize = kMaxBatchSize * 2 * 2 * 1 * sizeof(float);

  // Build without the delegate so the max tensor sizes can be set first.
  CustomFloorOpModel model(options, tensor_data, tensor_data,
                           /*allow_fp32_relax_to_fp16=*/false,
                           /*apply_delegate=*/false);
  for (const int tensor : {model.input(), model.output()}) {
    model.SetTensorMaxSize(tensor, kTensorMaxSize);
  }
  model.ApplyNNAPIDelegate();

  // Largest supported batch.
  EXPECT_EQ(model.ResizeInputTensor(model.input(), {2, 2, 2, 1}), kTfLiteOk);
  EXPECT_EQ(model.AllocateTensors(), kTfLiteOk);
  model.PopulateTensor<float>(model.input(),
                              {0, 0.2, 1.7, 2.8, 3.4, 4.1, 5.9, 6.3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0}));

  // A smaller batch on the same delegated model.
  EXPECT_EQ(model.ResizeInputTensor(model.input(), {1, 2, 2, 1}), kTfLiteOk);
  EXPECT_EQ(model.AllocateTensors(), kTfLiteOk);
  model.PopulateTensor<float>(model.input(), {1.7, 2.8, 3.4, 4.1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1.0, 2.0, 3.0, 4.0}));
}

}  // namespace
}  // namespace tflite
