micro: copy operator LEAKY_RELU kernel from lite

This is a copy, with minimal modification, of the kernel and test for
the LEAKY_RELU operator from tensorflow/lite/kernels.
Adaptation to micro and addition to the micro build will follow in later
PR steps.

PR step 3 for issue #46161
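
For context, LEAKY_RELU computes f(x) = x for x > 0 and f(x) = alpha * x
otherwise. In the quantized kernel below, Prepare derives two quantized
multipliers (an identity multiplier from input_scale / output_scale and an
alpha multiplier from input_scale * alpha / output_scale), and Eval picks one
per element based on the sign of the zero-point-adjusted input before clamping
to the output type's range. The following is a minimal, self-contained sketch
of that per-element logic, not the TFLite code itself: the Rescale helper and
the plain float scales are stand-ins for the fixed-point
MultiplyByQuantizedMultiplier (multiplier, shift) pairs stored in
LeakyReluOpData.

    #include <algorithm>
    #include <cstdint>

    // Stand-in for MultiplyByQuantizedMultiplier(x, multiplier, shift); a plain
    // float scale keeps the sketch short (hypothetical helper, not TFLite API).
    int32_t Rescale(int32_t x, float scale) {
      return static_cast<int32_t>(x * scale);
    }

    // Per-element int8 leaky ReLU: one rescale factor when x >= 0, another
    // (already scaled by alpha) when x < 0, then clamp to the int8 range.
    int8_t LeakyReluInt8Sketch(int8_t input, int32_t input_zp, int32_t output_zp,
                               float identity_scale, float alpha_scale) {
      const int32_t x = static_cast<int32_t>(input) - input_zp;
      const int32_t y = output_zp + Rescale(x, x >= 0 ? identity_scale : alpha_scale);
      return static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, y)));
    }

In the real kernel the two scales are produced by QuantizeMultiplier in
LeakyReluPrepare, and the clamp bounds come from std::numeric_limits<T>, so the
same template handles uint8, int8 and int16.
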
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 59341c8..3119ac4 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -245,6 +245,10 @@
       return ParsePool(op, error_reporter, allocator, builtin_data);
     }
 
+    case BuiltinOperator_LEAKY_RELU: {
+      return ParseLeakyRelu(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_LESS: {
       return ParseLess(op, error_reporter, allocator, builtin_data);
     }
@@ -674,16 +678,6 @@
       *builtin_data = params.release();
       return kTfLiteOk;
     }
-    case BuiltinOperator_LEAKY_RELU: {
-      auto params = safe_allocator.Allocate<TfLiteLeakyReluParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* leaky_relu_params =
-              op->builtin_options_as_LeakyReluOptions()) {
-        params->alpha = leaky_relu_params->alpha();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
     case BuiltinOperator_MIRROR_PAD: {
       auto params = safe_allocator.Allocate<TfLiteMirrorPaddingParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -1247,6 +1241,22 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseLeakyRelu(const Operator* op, ErrorReporter* error_reporter,
+                            BuiltinDataAllocator* allocator,
+                            void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteLeakyReluParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+  if (const auto* leaky_relu_params =
+          op->builtin_options_as_LeakyReluOptions()) {
+    params->alpha = leaky_relu_params->alpha();
+  }
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h
index 8b4a026..375d183 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -148,6 +148,10 @@
                                   BuiltinDataAllocator* allocator,
                                   void** builtin_data);
 
+TfLiteStatus ParseLeakyRelu(const Operator* op, ErrorReporter* error_reporter,
+                            BuiltinDataAllocator* allocator,
+                            void** builtin_data);
+
 TfLiteStatus ParseLess(const Operator* op, ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data);
 
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 728d386..df4234f 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -480,6 +480,7 @@
         "reference/integer_ops/tanh.h",
         "reference/integer_ops/transpose_conv.h",
         "reference/l2normalization.h",
+        "reference/leaky_relu.h",
         "reference/logistic.h",
         "reference/maximum_minimum.h",
         "reference/mul.h",
@@ -576,6 +577,7 @@
         "reference/fully_connected.h",
         "reference/hard_swish.h",
         "reference/l2normalization.h",
+        "reference/leaky_relu.h",
         "reference/legacy_reference_ops.h",
         "reference/logistic.h",
         "reference/maximum_minimum.h",
diff --git a/tensorflow/lite/kernels/internal/reference/leaky_relu.h b/tensorflow/lite/kernels/internal/reference/leaky_relu.h
new file mode 100644
index 0000000..beedddc
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/leaky_relu.h
@@ -0,0 +1,66 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEAKY_RELU_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEAKY_RELU_H_
+
+#include <algorithm>
+#include <limits>
+
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void LeakyRelu(const tflite::LeakyReluParams& params,
+                      const RuntimeShape& input_shape, const float* input_data,
+                      const RuntimeShape& output_shape, float* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const float val = input_data[i];
+    // Note that alpha might be > 1 or < 0, so we don't use std::max here.
+    output_data[i] = val > 0 ? val : val * params.alpha;
+  }
+}
+
+template <typename T>
+inline void QuantizeLeakyRelu(const LeakyReluParams& params,
+                              const RuntimeShape& input_shape,
+                              const T* input_data,
+                              const RuntimeShape& output_shape,
+                              T* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  static const int32_t quantized_min = std::numeric_limits<T>::min();
+  static const int32_t quantized_max = std::numeric_limits<T>::max();
+  for (int i = 0; i < flat_size; ++i) {
+    const int32_t input_value = input_data[i] - params.input_offset;
+    int32_t unclamped_output;
+    if (input_value >= 0) {
+      unclamped_output = params.output_offset +
+                         MultiplyByQuantizedMultiplier(
+                             input_value, params.output_multiplier_identity,
+                             params.output_shift_identity);
+    } else {
+      unclamped_output = params.output_offset +
+                         MultiplyByQuantizedMultiplier(
+                             input_value, params.output_multiplier_alpha,
+                             params.output_shift_alpha);
+    }
+    const T clamped_output =
+        std::min(quantized_max, std::max(quantized_min, unclamped_output));
+    output_data[i] = static_cast<T>(clamped_output);
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEAKY_RELU_H_
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index a7270e8..a4fe98b 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -48,6 +48,7 @@
 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
 #include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
+#include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
 #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
 #include "tensorflow/lite/kernels/internal/reference/mul.h"
@@ -211,48 +212,6 @@
   }
 }
 
-inline void LeakyRelu(const tflite::LeakyReluParams& params,
-                      const RuntimeShape& input_shape, const float* input_data,
-                      const RuntimeShape& output_shape, float* output_data) {
-  ruy::profiler::ScopeLabel label("LeakyRelu (not fused)");
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    const float val = input_data[i];
-    // Note that alpha might be > 1 or < 0, so we don't use std::max here.
-    output_data[i] = val > 0 ? val : val * params.alpha;
-  }
-}
-
-template <typename T>
-inline void QuantizeLeakyRelu(const LeakyReluParams& params,
-                              const RuntimeShape& input_shape,
-                              const T* input_data,
-                              const RuntimeShape& output_shape,
-                              T* output_data) {
-  ruy::profiler::ScopeLabel label("Quantized LeakyRelu (not fused)");
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-  static const int32 quantized_min = std::numeric_limits<T>::min();
-  static const int32 quantized_max = std::numeric_limits<T>::max();
-  for (int i = 0; i < flat_size; ++i) {
-    const int32 input_value = input_data[i] - params.input_offset;
-    int32 unclamped_output;
-    if (input_value >= 0) {
-      unclamped_output = params.output_offset +
-                         MultiplyByQuantizedMultiplier(
-                             input_value, params.output_multiplier_identity,
-                             params.output_shift_identity);
-    } else {
-      unclamped_output = params.output_offset +
-                         MultiplyByQuantizedMultiplier(
-                             input_value, params.output_multiplier_alpha,
-                             params.output_shift_alpha);
-    }
-    const T clamped_output =
-        std::min(quantized_max, std::max(quantized_min, unclamped_output));
-    output_data[i] = static_cast<T>(clamped_output);
-  }
-}
-
 // T is expected to be either float or int.
 template <typename T>
 inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
diff --git a/tensorflow/lite/micro/kernels/leaky_relu.cc b/tensorflow/lite/micro/kernels/leaky_relu.cc
new file mode 100644
index 0000000..b66e80e
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/leaky_relu.cc
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stddef.h>
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <functional>
+#include <limits>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/cppmath.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
+#include "tensorflow/lite/kernels/internal/reference/logistic.h"
+#include "tensorflow/lite/kernels/internal/reference/prelu.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/lite/kernels/internal/reference/tanh.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace activations {
+namespace {
+
+// OLD-TODO(b/142762739): We should figure out a multi-threading plan for most
+// of the activation ops below.
+
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+  kFixedPointOptimized,
+};
+
+struct OpData {
+  int32_t input_multiplier = 0;
+  int input_left_shift = 0;
+  int32_t input_range_radius = 0;
+  int diff_min = 0;
+  uint8_t table[256] = {0};
+};
+
+struct LeakyReluOpData : public OpData {
+  int32_t output_multiplier_alpha = 0;
+  int32_t output_shift_alpha = 0;
+  int32_t output_multiplier_identity = 0;
+  int32_t output_shift_identity = 0;
+};
+
+template <typename T>
+void QuantizeLeakyRelu(const TfLiteTensor* input, TfLiteTensor* output,
+                       const LeakyReluOpData* data) {
+  LeakyReluParams op_params;
+
+  op_params.input_offset = input->params.zero_point;
+  op_params.output_offset = output->params.zero_point;
+  op_params.output_multiplier_alpha = data->output_multiplier_alpha;
+  op_params.output_shift_alpha = data->output_shift_alpha;
+  op_params.output_multiplier_identity = data->output_multiplier_identity;
+  op_params.output_shift_identity = data->output_shift_identity;
+  reference_ops::QuantizeLeakyRelu(
+      op_params, GetTensorShape(input), GetTensorData<T>(input),
+      GetTensorShape(output), GetTensorData<T>(output));
+}
+
+}  // namespace
+
+void* LeakyReluInit(TfLiteContext* context, const char* buffer, size_t length) {
+  return new LeakyReluOpData;
+}
+
+void LeakyReluFree(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<LeakyReluOpData*>(buffer);
+}
+
+TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+
+  LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data);
+
+  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
+      output->type == kTfLiteInt16) {
+    const auto* params =
+        reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
+
+    double alpha_multiplier =
+        input->params.scale * params->alpha / output->params.scale;
+    QuantizeMultiplier(alpha_multiplier, &data->output_multiplier_alpha,
+                       &data->output_shift_alpha);
+    double identity_multiplier = input->params.scale / output->params.scale;
+    QuantizeMultiplier(identity_multiplier, &data->output_multiplier_identity,
+                       &data->output_shift_identity);
+  }
+
+  if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+  }
+
+  return context->ResizeTensor(context, output,
+                               TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const auto* params =
+      reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
+  const LeakyReluOpData* data =
+      reinterpret_cast<LeakyReluOpData*>(node->user_data);
+
+  LeakyReluParams op_params;
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      op_params.alpha = params->alpha;
+      optimized_ops::LeakyRelu(
+          op_params, GetTensorShape(input), GetTensorData<float>(input),
+          GetTensorShape(output), GetTensorData<float>(output));
+      return kTfLiteOk;
+    } break;
+    case kTfLiteUInt8: {
+      QuantizeLeakyRelu<uint8_t>(input, output, data);
+      return kTfLiteOk;
+    } break;
+    case kTfLiteInt8: {
+      QuantizeLeakyRelu<int8_t>(input, output, data);
+      return kTfLiteOk;
+    } break;
+    case kTfLiteInt16: {
+      QuantizeLeakyRelu<int16_t>(input, output, data);
+      return kTfLiteOk;
+    } break;
+    default:
+      TF_LITE_KERNEL_LOG(
+          context,
+          "Only float32, int8, int16 and uint8 are supported currently, got %s.",
+          TfLiteTypeGetName(input->type));
+      return kTfLiteError;
+  }
+}
+
+}  // namespace activations
+
+TfLiteRegistration* Register_LEAKY_RELU() {
+  static TfLiteRegistration r = {
+      activations::LeakyReluInit, activations::LeakyReluFree,
+      activations::LeakyReluPrepare, activations::LeakyReluEval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/leaky_relu_test.cc b/tensorflow/lite/micro/kernels/leaky_relu_test.cc
new file mode 100644
index 0000000..4314801
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/leaky_relu_test.cc
@@ -0,0 +1,211 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include <math.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <initializer_list>
+#include <limits>
+#include <map>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/string_type.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+  // A dedicated constructor for LeakyRelu, which sets up the output
+  // quantization options.
+  BaseActivationsOpModel(TensorData input, float alpha) {
+    input_ = AddInput(input);
+    // The output scale and input scale might be different.
+    if (input.type == TensorType_UINT8 || input.type == TensorType_INT8 ||
+        input.type == TensorType_INT16) {
+      auto output_min = (input.min >= 0) ? input.min : input.min * alpha;
+      auto output_max = (input.max >= 0) ? input.max : input.max * alpha;
+      if (input.type == TensorType_INT16) {
+        output_ = AddOutput({TensorType_INT16,
+                             {},
+                             0,
+                             0,
+                             output_max / (std::numeric_limits<int16_t>::max()),
+                             0});
+      } else {
+        output_ = AddOutput({input.type, {}, output_min, output_max});
+      }
+    } else {
+      output_ = AddOutput({input.type, {}});
+    }
+    SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
+                 CreateLeakyReluOptions(builder_, alpha).Union());
+    BuildInterpreter({GetShape(input_)});
+  }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+// Our fixed-point math function implementations have roughly 12 bits of
+// accuracy, when specialized to 16-bit fixed-point arithmetic.
+// That is purely an implementation compromise, it would have been possible
+// to get closer to 16 bits of accuracy but that would be more expensive,
+// and not needed for our purposes as ultimately the output is either
+// immediately down-quantized to 8 bits, or will typically be at the output
+// of the surrounding LSTM cell.
+// So we can require roughly 2^-12 accuracy when the output is 16-bit, and
+// we can more or less expect the full 2^-8 accuracy when the output is 8-bit.
+//
+// However, the representable output interval is often [-1, 1]  (it has to be
+// for tanh, and even for logistic, when we implement it in fixed-point, we
+// typically have to do so on such a symmetric interval, e.g. ARM NEON only
+// has signed fixed-point arithmetic (SQRDMULH)).  As the width of [-1, 1]
+// is 2, our representable values are often diluted by a factor of 2, whence
+// the factor of 2 below.
+const float kQuantizedTolerance = 2 * (1. / 256);
+const float kQuantizedToleranceInt16 = 2 * (1. / 4096);
+
+class QuantizedActivationsOpModel : public BaseActivationsOpModel {
+ public:
+  using BaseActivationsOpModel::BaseActivationsOpModel;
+
+  template <typename T>
+  void SetInput(const std::vector<float>& data) {
+    QuantizeAndPopulate<T>(input_, data);
+  }
+  template <typename T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_);
+  }
+
+  template <typename T>
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
+                         GetZeroPoint(output_));
+  }
+};
+
+TEST(QuantizedActivationsOpTest, LeakyReluUint8) {
+  const float kMin = -1;
+  const float kMax = 127.f / 128.f;
+  QuantizedActivationsOpModel m(
+      /*input=*/{TensorType_UINT8, {2, 3}, 8 * kMin, 8 * kMax}, 0.5);
+
+  m.SetInput<uint8_t>({
+      0.0f, 1.0f, 3.0f,    // Row 1
+      1.0f, -1.0f, -2.0f,  // Row 2
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0.0f, 1.0f, 3.0f,    // Row 1
+                      1.0f, -0.5f, -1.0f,  // Row 2
+                  },
+                  kQuantizedTolerance * 8)));
+}
+
+template <TensorType tensor_type, typename integer_dtype>
+void QuantizedActivationsOpTestLeakyRelu() {
+  const float kMin = -1;
+  const float kMax =
+      std::numeric_limits<integer_dtype>::max() /
+      static_cast<float>(std::numeric_limits<integer_dtype>::max() + 1);
+
+  QuantizedActivationsOpModel m(
+      /*input=*/{tensor_type, {5, 5}, 5 * kMin, 5 * kMax}, 0.1);
+
+  m.SetInput<integer_dtype>({
+      -5.0f, -4.6f, -4.2f, -3.8f, -3.4f,  // Row 1
+      -3.0f, -2.6f, -2.2f, -1.8f, -1.4f,  // Row 2
+      -1.0f, -0.6f, -0.2f, 0.2f,  0.6f,   // Row 3
+      1.0f,  1.4f,  1.8f,  2.2f,  2.6f,   // Row 4
+      3.0f,  3.4f,  3.8f,  4.2f,  4.6f,   // Row 5
+  });
+  m.Invoke();
+
+  float kTestQuantizedTolerance = tensor_type == TensorType_INT16
+                                      ? kQuantizedToleranceInt16
+                                      : kQuantizedTolerance * 5;
+
+  EXPECT_THAT(m.GetDequantizedOutput<integer_dtype>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      -0.50f, -0.46f, -0.42f, -0.38f, -0.34f,  // Row 1
+                      -0.30f, -0.26f, -0.22f, -0.18f, -0.14f,  // Row 2
+                      -0.10f, -0.06f, -0.02f, 0.20f,  0.60f,   // Row 3
+                      1.00f,  1.40f,  1.80f,  2.20f,  2.60f,   // Row 4
+                      3.00f,  3.40f,  3.80f,  4.20f,  4.60f,   // Row 5
+                  },
+                  kTestQuantizedTolerance)));
+}
+
+TEST(QuantizedActivationsOpTest, LeakyReluInt8) {
+  QuantizedActivationsOpTestLeakyRelu<TensorType_INT8, int8_t>();
+}
+
+TEST(QuantizedActivationsOpTest, LeakyReluInt16) {
+  QuantizedActivationsOpTestLeakyRelu<TensorType_INT16, int16_t>();
+}
+
+class LeakyReluOpModel : public SingleOpModel {
+ public:
+  LeakyReluOpModel(const TensorData& input, float alpha) {
+    input_ = AddInput(input);
+    output_ = AddOutput(input);
+    SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
+                 CreateLeakyReluOptions(builder_, alpha).Union());
+    BuildInterpreter({GetShape(input_)});
+  }
+  void SetInput(std::initializer_list<float> data) {
+    PopulateTensor(input_, data);
+  }
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+TEST(FloatActivationsOpTest, LeakyRelu) {
+  LeakyReluOpModel m({TensorType_FLOAT32, {2, 3}}, 0.5f);
+
+  m.SetInput({
+      0.0f, 1.0f, 3.0f,    // Row 1
+      1.0f, -1.0f, -2.0f,  // Row 2
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0.0f, 1.0f, 3.0f,    // Row 1
+                                 1.0f, -0.5f, -1.0f,  // Row 2
+                             }));
+}
+
+}  // namespace
+}  // namespace tflite