Automated rollback of commit 6ded9aa2ba47fdad13487b822d238d3a9663ef03

PiperOrigin-RevId: 267278208
diff --git a/tensorflow/lite/experimental/micro/kernels/BUILD b/tensorflow/lite/experimental/micro/kernels/BUILD
index 83826c6..10cf1b9 100644
--- a/tensorflow/lite/experimental/micro/kernels/BUILD
+++ b/tensorflow/lite/experimental/micro/kernels/BUILD
@@ -20,7 +20,6 @@
         "comparisons.cc",
         "conv.cc",
         "depthwise_conv.cc",
-        "dequantize.cc",
         "elementwise.cc",
         "floor.cc",
         "fully_connected.cc",
@@ -30,7 +29,6 @@
         "pack.cc",
         "pooling.cc",
         "prelu.cc",
-        "quantize.cc",
         "reshape.cc",
         "round.cc",
         "softmax.cc",
@@ -79,7 +77,6 @@
         "ceil.cc",
         "comparisons.cc",
         "conv.cc",
-        "dequantize.cc",
         "elementwise.cc",
         "floor.cc",
         "fully_connected.cc",
@@ -90,7 +87,6 @@
         "pooling.cc",
         "portable_optimized/depthwise_conv.cc",
         "prelu.cc",
-        "quantize.cc",
         "reshape.cc",
         "round.cc",
         "softmax.cc",
@@ -427,34 +423,6 @@
     deps = ["//tensorflow/lite/c:c_api_internal"],
 )
 
-tflite_micro_cc_test(
-    name = "quantize_test",
-    srcs = [
-        "quantize_test.cc",
-    ],
-    deps = [
-        ":all_ops_resolver",
-        "//tensorflow/lite/c:c_api_internal",
-        "//tensorflow/lite/experimental/micro:micro_framework",
-        "//tensorflow/lite/experimental/micro/kernels:micro_utils",
-        "//tensorflow/lite/experimental/micro/testing:micro_test",
-    ],
-)
-
-tflite_micro_cc_test(
-    name = "dequantize_test",
-    srcs = [
-        "dequantize_test.cc",
-    ],
-    deps = [
-        ":all_ops_resolver",
-        "//tensorflow/lite/c:c_api_internal",
-        "//tensorflow/lite/experimental/micro:micro_framework",
-        "//tensorflow/lite/experimental/micro/kernels:micro_utils",
-        "//tensorflow/lite/experimental/micro/testing:micro_test",
-    ],
-)
-
 cc_library(
     name = "micro_utils",
     hdrs = ["micro_utils.h"],
diff --git a/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
index 55e5f44..ddbd114 100644
--- a/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
+++ b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
@@ -54,8 +54,7 @@
 TfLiteRegistration* Register_UNPACK();
 TfLiteRegistration* Register_NEG();
 TfLiteRegistration* Register_ADD();
-TfLiteRegistration* Register_QUANTIZE();
-TfLiteRegistration* Register_DEQUANTIZE();
+
 AllOpsResolver::AllOpsResolver() {
   AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
   AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
@@ -99,12 +98,6 @@
   AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
   AddBuiltin(BuiltinOperator_NEG, Register_NEG());
   AddBuiltin(BuiltinOperator_ADD, Register_ADD());
-  AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(),
-             /* min_version */ 1,
-             /* max_version */ 4);
-  AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
-             /* min_version */ 1,
-             /* max_version */ 4);
 }
 
 }  // namespace micro
diff --git a/tensorflow/lite/experimental/micro/kernels/dequantize.cc b/tensorflow/lite/experimental/micro/kernels/dequantize.cc
deleted file mode 100644
index d4861b6..0000000
--- a/tensorflow/lite/experimental/micro/kernels/dequantize.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace dequantize {
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
-  // TODO(b/140515557): Add cached dequant to improve hybrid model performance.
-  TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
-  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
-
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);
-  TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
-  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
-
-  tflite::DequantizationParams op_params;
-  op_params.zero_point = input->params.zero_point;
-  op_params.scale = input->params.scale;
-  switch (input->type) {
-    case kTfLiteUInt8:
-      reference_ops::Dequantize(
-          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-          GetTensorShape(output), GetTensorData<float>(output));
-      break;
-    case kTfLiteInt8:
-      reference_ops::Dequantize(
-          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
-          GetTensorShape(output), GetTensorData<float>(output));
-      break;
-    default:
-      context->ReportError(context, "Type %s (%d) not supported.",
-                           TfLiteTypeGetName(input->type), input->type);
-      return kTfLiteError;
-  }
-
-  return kTfLiteOk;
-}
-
-}  // namespace dequantize
-
-TfLiteRegistration* Register_DEQUANTIZE() {
-  static TfLiteRegistration r = {nullptr, nullptr, dequantize::Prepare,
-                                 dequantize::Eval};
-  return &r;
-}
-
-}  // namespace micro
-}  // namespace ops
-}  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/dequantize_test.cc b/tensorflow/lite/experimental/micro/kernels/dequantize_test.cc
deleted file mode 100644
index 4b0476e..0000000
--- a/tensorflow/lite/experimental/micro/kernels/dequantize_test.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
-#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
-#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
-
-namespace tflite {
-namespace testing {
-namespace {
-
-template <typename T>
-void TestDequantize(std::initializer_list<int> input_dims_data,
-                    std::initializer_list<T> input_data,
-                    std::initializer_list<int> output_dims_data,
-                    std::initializer_list<float> expected_output_data,
-                    float min, float max, float* output_data) {
-  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
-  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
-  const int output_dims_count = ElementCount(*output_dims);
-
-  ::tflite::ops::micro::AllOpsResolver resolver;
-
-  float scale = ScaleFromMinMax<T>(min, max);
-  int32_t zero_point = ZeroPointFromMinMax<T>(min, max);
-
-  // TFLite float array takes int size followed by a variable size float array.
-  struct {
-    TfLiteFloatArray arr;
-    float data[1];
-  } scale_array = {{1}, {scale}};
-
-  TfLiteAffineQuantization builtin_data = {
-      .scale = reinterpret_cast<TfLiteFloatArray*>(&scale_array),
-      .zero_point = IntArrayFromInitializer({1, zero_point}),
-  };
-
-  TfLiteTensor output_tensor =
-      CreateFloatTensor(output_data, output_dims, "output_tensor");
-  output_tensor.quantization.type = kTfLiteAffineQuantization;
-  output_tensor.quantization.params = &builtin_data;
-
-  // 1 input, 1 output.
-  constexpr int tensors_size = 2;
-  TfLiteTensor tensors[tensors_size] = {
-      CreateQuantizedTensor(input_data, input_dims, "input_tensor", min, max),
-      output_tensor,
-  };
-
-  TfLiteContext context;
-  PopulateContext(tensors, tensors_size, &context);
-
-  // Version 4 ops support int8 quantization.
-  const TfLiteRegistration* registration =
-      resolver.FindOp(tflite::BuiltinOperator_DEQUANTIZE, 4);
-
-  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
-
-  const char* init_data = reinterpret_cast<const char*>(&builtin_data);
-  size_t init_data_size = 0;
-  void* user_data = nullptr;
-  if (registration->init) {
-    user_data = registration->init(&context, init_data, init_data_size);
-  }
-
-  int inputs_array_data[] = {1, 0};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 1};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
-  int temporaries_array_data[] = {0};
-  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
-
-  TfLiteNode node;
-  node.inputs = inputs_array;
-  node.outputs = outputs_array;
-  node.temporaries = temporaries_array;
-  node.user_data = user_data;
-  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
-  node.custom_initial_data = nullptr;
-  node.custom_initial_data_size = 0;
-  node.delegate = nullptr;
-
-  if (registration->prepare) {
-    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
-  }
-  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
-
-  if (registration->free) {
-    registration->free(&context, user_data);
-  }
-
-  for (int i = 0; i < output_dims_count; ++i) {
-    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
-                              0.001);
-  }
-}
-
-}  // namespace
-}  // namespace testing
-}  // namespace tflite
-
-TF_LITE_MICRO_TESTS_BEGIN
-
-TF_LITE_MICRO_TEST(DequantizeOpTestUint8) {
-  // [-63.5, 64] -> scale=0.5, zero_point=127 for UINT8
-  float output[10];
-  tflite::testing::TestDequantize(
-      {2, 5, 2},
-      std::initializer_list<uint8_t>{0, 1, 2, 3, 4, 251, 252, 253, 254, 255},
-      {2, 5, 2}, {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64}, -63.5,
-      64, output);
-}
-
-TF_LITE_MICRO_TEST(DequantizeOpTestInt8) {
-  // [-63.5, 64] -> scale=0.5, zero_point=-1 for INT8
-  float output[10];
-  tflite::testing::TestDequantize(
-      {2, 5, 2},
-      std::initializer_list<int8_t>{-128, -127, -126, -125, -124, 123, 124, 125,
-                                    126, 127},
-      {2, 5, 2}, {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64}, -63.5,
-      64, output);
-}
-
-TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/kernels/quantize.cc b/tensorflow/lite/experimental/micro/kernels/quantize.cc
deleted file mode 100644
index 2da54ab..0000000
--- a/tensorflow/lite/experimental/micro/kernels/quantize.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/lite/kernels/internal/reference/quantize.h"
-
-#include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace quantize {
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  return nullptr;
-}
-
-void Free(TfLiteContext* context, void* buffer) {}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
-  TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
-  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
-
-  // TODO(b/128934713): Add support for fixed-point per-channel quantization.
-  // Currently this only support affine per-layer quantization.
-  TF_LITE_ENSURE_EQ(context, output->quantization.type,
-                    kTfLiteAffineQuantization);
-  const auto* affine_quantization =
-      reinterpret_cast<TfLiteAffineQuantization*>(output->quantization.params);
-  TF_LITE_ENSURE(context, affine_quantization);
-  TF_LITE_ENSURE(context, affine_quantization->scale);
-  TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
-
-  // TFLite micro currently supports
-  TF_LITE_ENSURE(context, input->type == kTfLiteFloat32);
-  TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
-                              output->type == kTfLiteInt8 ||
-                              output->type == kTfLiteInt16);
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
-  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
-
-  tflite::QuantizationParams op_params;
-  op_params.zero_point = output->params.zero_point;
-  op_params.scale = output->params.scale;
-  switch (output->type) {
-    case kTfLiteInt8:
-      reference_ops::AffineQuantize(
-          op_params, GetTensorShape(input), GetTensorData<float>(input),
-          GetTensorShape(output), GetTensorData<int8_t>(output));
-      break;
-    case kTfLiteUInt8:
-      reference_ops::AffineQuantize(
-          op_params, GetTensorShape(input), GetTensorData<float>(input),
-          GetTensorShape(output), GetTensorData<uint8_t>(output));
-      break;
-    default:
-      context->ReportError(context, "Output type %s (%d) not supported",
-                           TfLiteTypeGetName(input->type), output->type);
-      return kTfLiteError;
-  }
-
-  return kTfLiteOk;
-}
-
-}  // namespace quantize
-
-// This Op (QUANTIZE) quantizes the input and produces quantized output.
-// AffineQuantize takes scale and zero point and quantizes the float value to
-// quantized output, in int8 or uint8 format.
-TfLiteRegistration* Register_QUANTIZE() {
-  static TfLiteRegistration r = {quantize::Init, quantize::Free,
-                                 quantize::Prepare, quantize::Eval};
-  return &r;
-}
-
-}  // namespace micro
-}  // namespace ops
-}  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/quantize_test.cc b/tensorflow/lite/experimental/micro/kernels/quantize_test.cc
deleted file mode 100644
index b4f24ba..0000000
--- a/tensorflow/lite/experimental/micro/kernels/quantize_test.cc
+++ /dev/null
@@ -1,160 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
-#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
-#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
-
-namespace tflite {
-namespace testing {
-namespace {
-
-template <typename T>
-void TestQuantize(std::initializer_list<int> input_dims_data,
-                  std::initializer_list<float> input_data,
-                  std::initializer_list<int> output_dims_data,
-                  std::initializer_list<T> expected_output_data, float min,
-                  float max, T* output_data) {
-  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
-  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
-  const int output_dims_count = ElementCount(*output_dims);
-
-  ::tflite::ops::micro::AllOpsResolver resolver;
-
-  float scale = ScaleFromMinMax<T>(min, max);
-  int32_t zero_point = ZeroPointFromMinMax<T>(min, max);
-
-  // TFLite float array takes int size followed by a variable size float array.
-  struct {
-    TfLiteFloatArray arr;
-    float data[1];
-  } scale_array = {{1}, {scale}};
-
-  TfLiteAffineQuantization builtin_data = {
-      .scale = reinterpret_cast<TfLiteFloatArray*>(&scale_array),
-      .zero_point = IntArrayFromInitializer({1, static_cast<int>(zero_point)}),
-  };
-
-  TfLiteTensor output_tensor = CreateQuantizedTensor(output_data, output_dims,
-                                                     "output_tensor", min, max);
-  output_tensor.quantization.type = kTfLiteAffineQuantization;
-  output_tensor.quantization.params = &builtin_data;
-
-  // 1 input, 1 output.
-  constexpr int tensors_size = 2;
-  TfLiteTensor tensors[tensors_size] = {
-      CreateFloatTensor(input_data, input_dims, "input_tensor"),
-      output_tensor,
-  };
-
-  TfLiteContext context;
-  PopulateContext(tensors, tensors_size, &context);
-
-  // Version 4 ops support int8 quantization.
-  const TfLiteRegistration* registration =
-      resolver.FindOp(tflite::BuiltinOperator_QUANTIZE, 4);
-
-  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
-
-  const char* init_data = reinterpret_cast<const char*>(&builtin_data);
-  size_t init_data_size = 0;
-  void* user_data = nullptr;
-  if (registration->init) {
-    user_data = registration->init(&context, init_data, init_data_size);
-  }
-
-  int inputs_array_data[] = {1, 0};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 1};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
-  int temporaries_array_data[] = {0};
-  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
-
-  TfLiteNode node;
-  node.inputs = inputs_array;
-  node.outputs = outputs_array;
-  node.temporaries = temporaries_array;
-  node.user_data = user_data;
-  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
-  node.custom_initial_data = nullptr;
-  node.custom_initial_data_size = 0;
-  node.delegate = nullptr;
-
-  if (registration->prepare) {
-    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
-  }
-  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
-
-  if (registration->free) {
-    registration->free(&context, user_data);
-  }
-
-  for (int i = 0; i < output_dims_count; ++i) {
-    TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
-  }
-}
-
-}  // namespace
-}  // namespace testing
-}  // namespace tflite
-
-TF_LITE_MICRO_TESTS_BEGIN
-
-TF_LITE_MICRO_TEST(QuantizeOpTestUint8) {
-  // [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
-  uint8_t output[10];
-  tflite::testing::TestQuantize(
-      {2, 2, 5}, {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64},
-      {2, 2, 5},
-      std::initializer_list<uint8_t>{0, 1, 2, 3, 4, 251, 252, 253, 254, 255},
-      -63.5, 64, output);
-}
-
-TF_LITE_MICRO_TEST(QuantizeOpTestUint8NoScale) {
-  // [-127, 128] -> scale=1.0 zero_point=128 for UINT8
-  uint8_t output[10];
-  tflite::testing::TestQuantize(
-      {2, 2, 5}, {-127, -126, -125, -124, -123, 124, 125, 126, 127, 128},
-      {2, 2, 5},
-      std::initializer_list<uint8_t>{0, 1, 2, 3, 4, 251, 252, 253, 254, 255},
-      -127, 128, output);
-}
-
-TF_LITE_MICRO_TEST(QuantizeOpTestInt8) {
-  // [-63.5, 64] -> scale=0.5, zero_point=-1 for INT8
-  int8_t output[10];
-  tflite::testing::TestQuantize(
-      {2, 2, 5}, {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64},
-      {2, 2, 5},
-      std::initializer_list<int8_t>{-128, -127, -126, -125, -124, 123, 124, 125,
-                                    126, 127},
-      -63.5, 64, output);
-}
-
-TF_LITE_MICRO_TEST(QuantizeOpTestInt8NoScale) {
-  // [-128, 127] -> scale=1.0, zero_point=0 for INT8
-  int8_t output[10];
-  tflite::testing::TestQuantize(
-      {2, 2, 5}, {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64},
-      {2, 2, 5},
-      std::initializer_list<int8_t>{-64, -63, -63, -62, -62, 62, 63, 63, 64,
-                                    64},
-      -128, 127, output);
-}
-
-TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/tools/make/Makefile b/tensorflow/lite/experimental/micro/tools/make/Makefile
index 62981e2..d2ed2e6 100644
--- a/tensorflow/lite/experimental/micro/tools/make/Makefile
+++ b/tensorflow/lite/experimental/micro/tools/make/Makefile
@@ -116,7 +116,6 @@
 tensorflow/lite/kernels/internal/reference/conv.h \
 tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \
 tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \
-tensorflow/lite/kernels/internal/reference/dequantize.h \
 tensorflow/lite/kernels/internal/reference/floor.h \
 tensorflow/lite/kernels/internal/reference/fully_connected.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/add.h \
@@ -126,7 +125,6 @@
 tensorflow/lite/kernels/internal/reference/pooling.h \
 tensorflow/lite/kernels/internal/reference/prelu.h \
 tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h \
-tensorflow/lite/kernels/internal/reference/quantize.h \
 tensorflow/lite/kernels/internal/reference/round.h \
 tensorflow/lite/kernels/internal/reference/softmax.h \
 tensorflow/lite/kernels/internal/reference/strided_slice.h \
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index e57f223..c52503e 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -366,7 +366,6 @@
         "reference/conv.h",
         "reference/depthwiseconv_float.h",
         "reference/depthwiseconv_uint8.h",
-        "reference/dequantize.h",
         "reference/floor.h",
         "reference/fully_connected.h",
         "reference/integer_ops/add.h",
@@ -388,7 +387,6 @@
         "reference/pooling.h",
         "reference/prelu.h",
         "reference/process_broadcast_shapes.h",
-        "reference/quantize.h",
         "reference/reference_ops.h",
         "reference/round.h",
         "reference/softmax.h",
@@ -434,7 +432,6 @@
         "reference/conv.h",
         "reference/depthwiseconv_float.h",
         "reference/depthwiseconv_uint8.h",
-        "reference/dequantize.h",
         "reference/floor.h",
         "reference/fully_connected.h",
         "reference/legacy_reference_ops.h",
@@ -443,7 +440,6 @@
         "reference/pooling.h",
         "reference/prelu.h",
         "reference/process_broadcast_shapes.h",
-        "reference/quantize.h",
         "reference/reference_ops.h",
         "reference/round.h",
         "reference/softmax.h",
diff --git a/tensorflow/lite/kernels/internal/reference/dequantize.h b/tensorflow/lite/kernels/internal/reference/dequantize.h
deleted file mode 100644
index 1040001..0000000
--- a/tensorflow/lite/kernels/internal/reference/dequantize.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEQUANTIZE_H_
-#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEQUANTIZE_H_
-
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/types.h"
-
-namespace tflite {
-
-namespace reference_ops {
-
-template <typename T>
-inline void Dequantize(const tflite::DequantizationParams& op_params,
-                       const RuntimeShape& input_shape, const T* input_data,
-                       const RuntimeShape& output_shape, float* output_data) {
-  int32 zero_point = op_params.zero_point;
-  const double scale = op_params.scale;
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-
-  for (int i = 0; i < flat_size; i++) {
-    const int32 val = input_data[i];
-    const float result = static_cast<float>(scale * (val - zero_point));
-    output_data[i] = result;
-  }
-}
-
-}  // namespace reference_ops
-
-}  // namespace tflite
-#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEQUANTIZE_H_
diff --git a/tensorflow/lite/kernels/internal/reference/quantize.h b/tensorflow/lite/kernels/internal/reference/quantize.h
deleted file mode 100644
index 37e2bea..0000000
--- a/tensorflow/lite/kernels/internal/reference/quantize.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_QUANTIZE_H_
-#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_QUANTIZE_H_
-
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/round.h"
-#include "tensorflow/lite/kernels/internal/types.h"
-
-namespace tflite {
-
-namespace reference_ops {
-
-template <typename T>
-inline void AffineQuantize(const tflite::QuantizationParams& op_params,
-                           const RuntimeShape& input_shape,
-                           const float* input_data,
-                           const RuntimeShape& output_shape, T* output_data) {
-  const int32 zero_point = op_params.zero_point;
-  const double scale = static_cast<double>(op_params.scale);
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-  static constexpr int32 min_val = std::numeric_limits<T>::min();
-  static constexpr int32 max_val = std::numeric_limits<T>::max();
-
-  for (int i = 0; i < flat_size; i++) {
-    const float val = input_data[i];
-    int32 unclamped = static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
-    int32 clamped = std::min(std::max(unclamped, min_val), max_val);
-    output_data[i] = clamped;
-  }
-}
-
-}  // namespace reference_ops
-
-}  // namespace tflite
-#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_QUANTIZE_H_
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 53c135b..5f2e833 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -38,7 +38,6 @@
 #include "tensorflow/lite/kernels/internal/reference/ceil.h"
 #include "tensorflow/lite/kernels/internal/reference/comparisons.h"
 #include "tensorflow/lite/kernels/internal/reference/conv.h"
-#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
 #include "tensorflow/lite/kernels/internal/reference/floor.h"
 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
@@ -46,7 +45,6 @@
 #include "tensorflow/lite/kernels/internal/reference/pooling.h"
 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
 #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
-#include "tensorflow/lite/kernels/internal/reference/quantize.h"
 #include "tensorflow/lite/kernels/internal/reference/round.h"
 #include "tensorflow/lite/kernels/internal/reference/softmax.h"
 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
@@ -2029,6 +2027,21 @@
   }
 }
 
+inline void Dequantize(const tflite::DequantizationParams& op_params,  // Converts quantized uint8 input values to float output values.
+                       const RuntimeShape& input_shape, const uint8* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Dequantize");  // Scoped label; presumably tags this call for gemmlowp profiling -- confirm.
+  int32 zero_point = op_params.zero_point;
+  double scale = op_params.scale;  // Held as double, so the multiply below is evaluated in double.
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);  // Presumably the element count shared by both shapes -- confirm against helper.
+
+  for (int i = 0; i < flat_size; i++) {
+    int32 val = input_data[i];  // Widen the uint8 sample to int32 before subtracting the zero point.
+    float result = static_cast<float>(scale * (val - zero_point));  // real = scale * (q - zero_point), narrowed to float.
+    output_data[i] = result;
+  }
+}
+
 inline void Dequantize(const RuntimeShape& input_shape,
                        const Eigen::half* input_data,
                        const RuntimeShape& output_shape, float* output_data) {
@@ -2038,6 +2051,26 @@
   }
 }
 
+template <typename T>  // T is the quantized output element type.
+inline void AffineQuantize(const tflite::QuantizationParams& op_params,  // Quantizes float input to T using op_params.scale and op_params.zero_point.
+                           const RuntimeShape& input_shape,
+                           const float* input_data,
+                           const RuntimeShape& output_shape, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Quantize");  // Scoped label; presumably tags this call for gemmlowp profiling -- confirm.
+  const int32 zero_point = op_params.zero_point;
+  const double scale = static_cast<double>(op_params.scale);  // Division below is therefore carried out in double.
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);  // Presumably the element count shared by both shapes -- confirm against helper.
+  static constexpr int32 min_val = std::numeric_limits<T>::min();  // Lower clamp bound: smallest value T can hold.
+  static constexpr int32 max_val = std::numeric_limits<T>::max();  // Upper clamp bound: largest value T can hold.
+
+  for (int i = 0; i < flat_size; i++) {
+    const float val = input_data[i];
+    int32 unclamped = static_cast<int32>(TfLiteRound(val / scale)) + zero_point;  // q = round(val / scale) + zero_point.
+    int32 clamped = std::min(std::max(unclamped, min_val), max_val);  // Saturate into [min_val, max_val] before storing as T.
+    output_data[i] = clamped;
+  }
+}
+
 template <typename input_type, typename output_type>
 inline void Requantize(const input_type* input_data, int32_t size,
                        int32_t effective_scale_multiplier,