Added tests for FP32 computations with INT8 models in XNNPACK delegate
diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD
index 9ca0e9f..3d6d0fc 100644
--- a/tensorflow/lite/delegates/xnnpack/BUILD
+++ b/tensorflow/lite/delegates/xnnpack/BUILD
@@ -78,11 +78,22 @@
 ################################ Tester classes ################################
 
 cc_library(
+    name = "test_util",
+    testonly = 1,
+    srcs = ["test_util.cc"],
+    hdrs = ["test_util.h"],
+    deps = [
+        "//tensorflow/lite/kernels/internal:cppmath",
+    ],
+)
+
+cc_library(
     name = "binary_elementwise_tester",
     testonly = 1,
     srcs = ["binary_elementwise_tester.cc"],
     hdrs = ["binary_elementwise_tester.h"],
     deps = [
+        ":test_util",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite/c:common",
@@ -101,6 +112,7 @@
     srcs = ["conv_2d_tester.cc"],
     hdrs = ["conv_2d_tester.h"],
     deps = [
+        ":test_util",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite/c:common",
@@ -136,6 +148,7 @@
     srcs = ["depthwise_conv_2d_tester.cc"],
     hdrs = ["depthwise_conv_2d_tester.h"],
     deps = [
+        ":test_util",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite/c:common",
@@ -154,6 +167,7 @@
     srcs = ["fully_connected_tester.cc"],
     hdrs = ["fully_connected_tester.h"],
     deps = [
+        ":test_util",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite/c:common",
@@ -223,6 +237,7 @@
     srcs = ["prelu_tester.cc"],
     hdrs = ["prelu_tester.h"],
     deps = [
+        ":test_util",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite/c:common",
diff --git a/tensorflow/lite/delegates/xnnpack/add_test.cc b/tensorflow/lite/delegates/xnnpack/add_test.cc
index 5731683..69a898d 100644
--- a/tensorflow/lite/delegates/xnnpack/add_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/add_test.cc
@@ -708,6 +708,35 @@
       .Test(BuiltinOperator_ADD, xnnpack_delegate.get());
 }
 
+TEST(Add, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_ADD, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_ADD, xnnpack_delegate.get());
+}
+
 TEST(Add, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc
index 6007ddc..033e5b6 100644
--- a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc
+++ b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc
@@ -25,6 +25,7 @@
 #include <gtest/gtest.h>
 #include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/delegates/xnnpack/test_util.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
@@ -64,7 +65,7 @@
   if (Input1Static()) {
     ASSERT_FALSE(Input2Static());
   }
-  if (FP16Weights()) {
+  if (FP16Weights() || INT8Weights()) {
     ASSERT_TRUE(Input1Static() || Input2Static());
   }
 
@@ -191,7 +192,7 @@
   flatbuffers::FlatBufferBuilder builder;
   std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
       {CreateOperatorCode(builder, binary_op)}};
-  if (FP16Weights()) {
+  if (FP16Weights() || INT8Weights()) {
     operator_codes.emplace_back(
         CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
   } else if (SparseWeights()) {
@@ -214,6 +215,15 @@
           builder, builder.CreateVector(
                        reinterpret_cast<const uint8_t*>(input1_data.data()),
                        sizeof(uint16_t) * input1_data.size())));
+    } else if (INT8Weights()) {
+      std::vector<int8_t> input1_data(ComputeSize(Input1Shape()));
+      std::generate(input1_data.begin(), input1_data.end(),
+                    std::bind(QuantizeInt8, input1_rng, input1_zero_point_, input1_scale_));
+
+      buffers.push_back(CreateBuffer(
+              builder, builder.CreateVector(
+                      reinterpret_cast<const uint8_t*>(input1_data.data()),
+                      sizeof(int8_t) * input1_data.size())));
     } else {
       std::vector<float> input1_data(ComputeSize(Input1Shape()));
       std::generate(input1_data.begin(), input1_data.end(), input1_rng);
@@ -239,6 +249,15 @@
           builder, builder.CreateVector(
                        reinterpret_cast<const uint8_t*>(input2_data.data()),
                        sizeof(uint16_t) * input2_data.size())));
+    } else if (INT8Weights()) {
+      std::vector<int8_t> input2_data(ComputeSize(Input1Shape()));
+      std::generate(input2_data.begin(), input2_data.end(),
+                    std::bind(QuantizeInt8, input2_rng, input2_zero_point_, input2_scale_));
+
+      buffers.push_back(CreateBuffer(
+              builder, builder.CreateVector(
+                      reinterpret_cast<const uint8_t *>(input2_data.data()),
+                      sizeof(int8_t) * input2_data.size())));
     } else {
       std::vector<float> input2_data(ComputeSize(Input2Shape()));
       std::generate(input2_data.begin(), input2_data.end(), input2_rng);
@@ -262,6 +281,16 @@
                      builder.CreateVector<int32_t>(Input1Shape().data(),
                                                    Input1Shape().size()),
                      TensorType_FLOAT16, 1));
+  } else if (INT8Weights() && Input1Static()) {
+    tensors.emplace_back(
+        CreateTensor(builder,
+                     builder.CreateVector<int32_t>(Input1Shape().data(),
+                                                   Input1Shape().size()),
+                     TensorType_INT8, 1, 0,
+                     CreateQuantizationParameters(
+                         builder, /*min=*/0, /*max=*/0,
+                         builder.CreateVector<float>({input1_scale_}),
+                         builder.CreateVector<int64_t>({input1_zero_point_}))));
   } else if (SparseWeights() && Input1Static()) {
     int dims_count = Input1Shape().size();
     std::vector<flatbuffers::Offset<DimensionMetadata>> dim_metadata(
@@ -288,6 +317,16 @@
                      builder.CreateVector<int32_t>(Input2Shape().data(),
                                                    Input2Shape().size()),
                      TensorType_FLOAT16, 1));
+  } else if (INT8Weights() && Input2Static()) {
+    tensors.emplace_back(
+        CreateTensor(builder,
+                     builder.CreateVector<int32_t>(Input2Shape().data(),
+                                                   Input2Shape().size()),
+                     TensorType_INT8, 1, 0,
+                     CreateQuantizationParameters(
+                         builder, /*min=*/0, /*max=*/0,
+                         builder.CreateVector<float>({input2_scale_}),
+                         builder.CreateVector<int64_t>({input2_zero_point_}))));
   } else if (SparseWeights() && Input2Static()) {
     int dims_count = Input2Shape().size();
     std::vector<flatbuffers::Offset<DimensionMetadata>> dim_metadata(
@@ -308,7 +347,7 @@
         TensorType_FLOAT32, /*buffer=*/1, /*name=*/0, /*quantization=*/0,
         /*is_variable=*/false, /*sparsity=*/sparsity_param));
   }
-  if (FP16Weights()) {
+  if (FP16Weights() || INT8Weights()) {
     const std::array<int32_t, 1> dequantize_inputs{{0}};
     const std::array<int32_t, 1> dequantize_outputs{{Input1Static() ? 1 : 2}};
     operators.emplace_back(CreateOperator(
diff --git a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h
index 3d476ba..89a0ef4 100644
--- a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h
+++ b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h
@@ -81,6 +81,13 @@
 
   inline bool FP16Weights() const { return fp16_weights_; }
 
+  inline BinaryElementwiseTester& INT8Weights() {
+    int8_weights_ = true;
+    return *this;
+  }
+
+  inline bool INT8Weights() const { return int8_weights_; }
+
   inline BinaryElementwiseTester& SparseWeights() {
     sparse_weights_ = true;
     return *this;
@@ -129,7 +136,12 @@
   bool input1_static_ = false;
   bool input2_static_ = false;
   bool fp16_weights_ = false;
+  bool int8_weights_ = false;
   bool sparse_weights_ = false;
+  int8_t input1_zero_point_ = 0;
+  int8_t input2_zero_point_ = 0;
+  float input1_scale_ = 0.75f;
+  float input2_scale_ = 1.0f;
   ::tflite::ActivationFunctionType activation_ =
       ::tflite::ActivationFunctionType_NONE;
 };
diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc
index b794ee2..efb1add 100644
--- a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc
@@ -321,6 +321,38 @@
       .Test(xnnpack_delegate.get());
 }
 
+TEST(Conv2D, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));
+
+  Conv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .INT8Weights()
+      .Test(xnnpack_delegate.get());
+}
+
 TEST(Conv2D, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
@@ -386,6 +418,39 @@
       .Test(xnnpack_delegate.get());
 }
 
+TEST(Conv2D, SparseINT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));
+
+  Conv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .SparseWeights()
+      .INT8Weights()
+      .Test(xnnpack_delegate.get());
+}
+
 TEST(Conv2D, ReluActivation) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc
index 8ce5ae0..b128692 100644
--- a/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc
+++ b/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc
@@ -24,6 +24,7 @@
 #include <gtest/gtest.h>
 #include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/delegates/xnnpack/test_util.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
@@ -130,7 +131,7 @@
         CreateOperatorCode(builder, BuiltinOperator_DENSIFY));
     const std::array<int32_t, 1> densify_filter_inputs{{0}};
     const std::array<int32_t, 1> densify_filter_outputs{
-        {FP16Weights() ? 1 : 2}};
+        {(FP16Weights() || INT8Weights()) ? 1 : 2}};
     operators.emplace_back(CreateOperator(
         builder, /*opcode_index=*/operator_codes.size() - 1,
         builder.CreateVector<int32_t>(densify_filter_inputs.data(),
@@ -200,6 +201,68 @@
                                       dequantize_bias_inputs.size()),
         builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
                                       dequantize_bias_outputs.size())));
+  } else if (INT8Weights()) {
+    operator_codes.emplace_back(
+        CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
+
+    std::vector<int8_t> filter_data(OutputChannels() * KernelHeight() *
+                                    KernelWidth() * InputChannels());
+    std::vector<int8_t> bias_data(OutputChannels());
+    for (int32_t oc = 0; oc < OutputChannels(); oc++) {
+      // Use the same range of all-positive or all-negative values to generate
+      // all weights within the same output channel, but different ranges for
+      // different output channels. This ensures that no catastrophic
+      // cancellation occur, but test covers both positive and negative inputs.
+      const float range = range_rng();
+      auto value_rng =
+          std::bind(QuantizeInt8,
+                    std::bind(std::uniform_real_distribution<float>(
+                                  std::min(range, 0.0f), std::max(range, 0.0f)),
+                              std::ref(rng)),
+                    weights_zero_point_, weights_scale_);
+      bias_data[oc] = value_rng();
+      for (int32_t ic = 0; ic < InputChannels(); ic++) {
+        for (int32_t y = 0; y < KernelHeight(); y++) {
+          for (int32_t x = 0; x < KernelWidth(); x++) {
+            const int32_t index =
+                ((oc * KernelHeight() + y) * KernelWidth() + x) *
+                    InputChannels() +
+                ic;
+            filter_data[index] = value_rng();
+          }
+        }
+      }
+    }
+
+    buffers.emplace_back(CreateBuffer(
+        builder, builder.CreateVector(
+                     reinterpret_cast<const uint8_t*>(filter_data.data()),
+                     sizeof(int8_t) * filter_data.size())));
+    buffers.emplace_back(CreateBuffer(
+        builder,
+        builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
+                             sizeof(int8_t) * bias_data.size())));
+
+    const std::array<int32_t, 1> dequantize_filter_inputs{
+        {SparseWeights() ? 1 : 0}};
+    const std::array<int32_t, 1> dequantize_filter_outputs{
+        {SparseWeights() ? 4 : 3}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/operator_codes.size() - 1,
+        builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
+                                      dequantize_filter_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
+                                      dequantize_filter_outputs.size())));
+    const std::array<int32_t, 1> dequantize_bias_inputs{
+        {SparseWeights() ? 2 : 1}};
+    const std::array<int32_t, 1> dequantize_bias_outputs{
+        {SparseWeights() ? 5 : 4}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/operator_codes.size() - 1,
+        builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
+                                      dequantize_bias_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
+                                      dequantize_bias_outputs.size())));
   } else {
     std::vector<float> filter_data(OutputChannels() * KernelHeight() *
                                    KernelWidth() * InputChannels());
@@ -265,12 +328,25 @@
     flatbuffers::Offset<SparsityParameters> sparsity_param =
         CreateSparsityParameters(builder, builder.CreateVector(traversal_order),
                                  0, builder.CreateVector(dim_metadata));
-    tensors.emplace_back(CreateTensor(
-        builder,
-        builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
-        /*type=*/FP16Weights() ? TensorType_FLOAT16 : TensorType_FLOAT32,
-        /*buffer=*/1, /*name=*/0, /*quantization=*/0,
-        /*is_variable=*/false, /*sparsity=*/sparsity_param));
+    if (INT8Weights()) {
+      tensors.emplace_back(CreateTensor(
+          builder,
+          builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+          /*type=*/TensorType_INT8,
+          /*buffer=*/1, /*name=*/0,
+          CreateQuantizationParameters(
+              builder, /*min=*/0, /*max=*/0,
+              builder.CreateVector<float>({weights_scale_}),
+              builder.CreateVector<int64_t>({weights_zero_point_})),
+          /*is_variable=*/false, /*sparsity=*/sparsity_param));
+    } else {
+      tensors.emplace_back(CreateTensor(
+          builder,
+          builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+          /*type=*/FP16Weights() ? TensorType_FLOAT16 : TensorType_FLOAT32,
+          /*buffer=*/1, /*name=*/0, /*quantization=*/0,
+          /*is_variable=*/false, /*sparsity=*/sparsity_param));
+    }
   }
   if (FP16Weights()) {
     tensors.emplace_back(CreateTensor(
@@ -281,6 +357,23 @@
         builder,
         builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
         TensorType_FLOAT16, /*buffer=*/2));
+  } else if (INT8Weights()) {
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+        TensorType_INT8, /*buffer=*/SparseWeights() ? 0 : 1, /*name=*/0,
+        CreateQuantizationParameters(
+            builder, /*min=*/0, /*max=*/0,
+            builder.CreateVector<float>({weights_scale_}),
+            builder.CreateVector<int64_t>({weights_zero_point_}))));
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+        TensorType_INT8, /*buffer=*/2, /*name=*/0,
+        CreateQuantizationParameters(
+            builder, /*min=*/0, /*max=*/0,
+            builder.CreateVector<float>({weights_scale_}),
+            builder.CreateVector<int64_t>({weights_zero_point_}))));
   }
   tensors.emplace_back(CreateTensor(
       builder,
@@ -289,11 +382,12 @@
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
-      TensorType_FLOAT32, /*buffer=*/FP16Weights() || SparseWeights() ? 0 : 1));
+      TensorType_FLOAT32, /*buffer=*/(FP16Weights() || INT8Weights() ||
+                                      SparseWeights()) ? 0 : 1));
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
-      TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
+      TensorType_FLOAT32, /*buffer=*/(FP16Weights() || INT8Weights()) ? 0 : 2));
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_tester.h b/tensorflow/lite/delegates/xnnpack/conv_2d_tester.h
index d0a021c..f94730b 100644
--- a/tensorflow/lite/delegates/xnnpack/conv_2d_tester.h
+++ b/tensorflow/lite/delegates/xnnpack/conv_2d_tester.h
@@ -155,6 +155,13 @@
 
   inline bool FP16Weights() const { return fp16_weights_; }
 
+  inline Conv2DTester& INT8Weights() {
+    int8_weights_ = true;
+    return *this;
+  }
+
+  inline bool INT8Weights() const { return int8_weights_; }
+
   inline Conv2DTester& SparseWeights() {
     sparse_weights_ = true;
     return *this;
@@ -220,7 +227,10 @@
   int32_t dilation_height_ = 1;
   int32_t dilation_width_ = 1;
   bool fp16_weights_ = false;
+  bool int8_weights_ = false;
   bool sparse_weights_ = false;
+  int8_t weights_zero_point_ = 0;
+  float weights_scale_ = 0.75f;
   ::tflite::Padding padding_ = ::tflite::Padding_VALID;
   ::tflite::ActivationFunctionType activation_ =
       ::tflite::ActivationFunctionType_NONE;
diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc
index ad8b69a..d45529f 100644
--- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc
@@ -402,6 +402,37 @@
       .Test(xnnpack_delegate.get());
 }
 
+TEST(DepthwiseConv2D, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
+
+  DepthwiseConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .INT8Weights()
+      .Test(xnnpack_delegate.get());
+}
+
 TEST(DepthwiseConv2D, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc
index a14fc15..816eb91 100644
--- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc
+++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc
@@ -24,6 +24,7 @@
 #include <gtest/gtest.h>
 #include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/delegates/xnnpack/test_util.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
@@ -181,6 +182,63 @@
                                       dequantize_bias_inputs.size()),
         builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
                                       dequantize_bias_outputs.size())));
+  } else if (INT8Weights()) {
+    operator_codes.emplace_back(
+        CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
+
+    std::vector<int8_t> filter_data(KernelHeight() * KernelWidth() *
+                                    OutputChannels());
+    std::vector<int8_t> bias_data(OutputChannels());
+    for (int32_t ic = 0; ic < InputChannels(); ic++) {
+      // Use the same range of all-positive or all-negative values to generate
+      // all pixels within the same batch index & channel, but different ranges
+      // for different channels or batches. This ensures that no catastrophic
+      // cancellation occur, but test covers both positive and negative inputs.
+      const float range = range_rng();
+      auto value_rng =
+          std::bind(QuantizeInt8,
+                    std::bind(std::uniform_real_distribution<float>(
+                                  std::min(range, 0.0f), std::max(range, 0.0f)),
+                              std::ref(rng)),
+                    weights_zero_point_, weights_scale_);
+      for (int32_t m = 0; m < DepthMultiplier(); m++) {
+        const int32_t oc = ic * DepthMultiplier() + m;
+        bias_data[oc] = value_rng();
+        for (int32_t y = 0; y < KernelHeight(); y++) {
+          for (int32_t x = 0; x < KernelWidth(); x++) {
+            const int32_t index =
+                (y * KernelWidth() + x) * OutputChannels() + oc;
+            filter_data[index] = value_rng();
+          }
+        }
+      }
+    }
+
+    buffers.emplace_back(CreateBuffer(
+        builder, builder.CreateVector(
+                     reinterpret_cast<const uint8_t*>(filter_data.data()),
+                     sizeof(int8_t) * filter_data.size())));
+    buffers.emplace_back(CreateBuffer(
+        builder,
+        builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
+                             sizeof(int8_t) * bias_data.size())));
+
+    const std::array<int32_t, 1> dequantize_filter_inputs{{0}};
+    const std::array<int32_t, 1> dequantize_filter_outputs{{3}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
+                                      dequantize_filter_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
+                                      dequantize_filter_outputs.size())));
+    const std::array<int32_t, 1> dequantize_bias_inputs{{1}};
+    const std::array<int32_t, 1> dequantize_bias_outputs{{4}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
+                                      dequantize_bias_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
+                                      dequantize_bias_outputs.size())));
   } else {
     std::vector<float> filter_data(KernelHeight() * KernelWidth() *
                                    OutputChannels());
@@ -249,6 +307,23 @@
         builder,
         builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
         TensorType_FLOAT16, /*buffer=*/2));
+  } else if (INT8Weights()) {
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+        TensorType_INT8, /*buffer=*/1, /*name=*/0,
+        CreateQuantizationParameters(
+            builder, /*min=*/0, /*max=*/0,
+            builder.CreateVector<float>({weights_scale_}),
+            builder.CreateVector<int64_t>({weights_zero_point_}))));
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+        TensorType_INT8, /*buffer=*/2, /*name=*/0,
+        CreateQuantizationParameters(
+            builder, /*min=*/0, /*max=*/0,
+            builder.CreateVector<float>({weights_scale_}),
+            builder.CreateVector<int64_t>({weights_zero_point_}))));
   } else if (SparseWeights()) {
     // Sparse tensor in TFLite can be in different formats. Here we choose the
     // simplest configuration that
@@ -280,11 +355,12 @@
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
-      TensorType_FLOAT32, /*buffer=*/FP16Weights() || SparseWeights() ? 0 : 1));
+      TensorType_FLOAT32, /*buffer=*/(FP16Weights() || INT8Weights() ||
+                                      SparseWeights()) ? 0 : 1));
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
-      TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
+      TensorType_FLOAT32, /*buffer=*/(FP16Weights() || INT8Weights()) ? 0 : 2));
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h
index acc82ed..ea27cb7 100644
--- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h
+++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h
@@ -159,6 +159,13 @@
 
   inline bool FP16Weights() const { return fp16_weights_; }
 
+  inline DepthwiseConv2DTester& INT8Weights() {
+    int8_weights_ = true;
+    return *this;
+  }
+
+  inline bool INT8Weights() const { return int8_weights_; }
+
   inline DepthwiseConv2DTester& SparseWeights() {
     sparse_weights_ = true;
     return *this;
@@ -224,7 +231,10 @@
   int32_t dilation_height_ = 1;
   int32_t dilation_width_ = 1;
   bool fp16_weights_ = false;
+  bool int8_weights_ = false;
   bool sparse_weights_ = false;
+  int8_t weights_zero_point_ = 0;
+  float weights_scale_ = 0.75f;
   ::tflite::Padding padding_ = ::tflite::Padding_VALID;
   ::tflite::ActivationFunctionType activation_ =
       ::tflite::ActivationFunctionType_NONE;
diff --git a/tensorflow/lite/delegates/xnnpack/div_test.cc b/tensorflow/lite/delegates/xnnpack/div_test.cc
index 5e085ca..b338870 100644
--- a/tensorflow/lite/delegates/xnnpack/div_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/div_test.cc
@@ -708,6 +708,35 @@
       .Test(BuiltinOperator_DIV, xnnpack_delegate.get());
 }
 
+TEST(Div, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_DIV, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_DIV, xnnpack_delegate.get());
+}
+
 TEST(Div, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_test.cc b/tensorflow/lite/delegates/xnnpack/fully_connected_test.cc
index 0dffd1d..c0dd8c0 100644
--- a/tensorflow/lite/delegates/xnnpack/fully_connected_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/fully_connected_test.cc
@@ -251,6 +251,29 @@
       .Test(xnnpack_delegate.get());
 }
 
+TEST(FullyConnected, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  auto channels_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 9), std::ref(rng));
+  const auto batch = batch_rng();
+  const auto input_channels = channels_rng();
+  const auto output_channels = channels_rng();
+
+  FullyConnectedTester()
+      .InputShape({batch, input_channels})
+      .InputChannels(input_channels)
+      .OutputChannels(output_channels)
+      .INT8Weights()
+      .Test(xnnpack_delegate.get());
+}
+
 TEST(FullyConnected, ReluActivation) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc
index c55555d..6e81068 100644
--- a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc
+++ b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc
@@ -25,6 +25,7 @@
 #include <gtest/gtest.h>
 #include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/delegates/xnnpack/test_util.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
@@ -175,6 +176,57 @@
                                       dequantize_bias_inputs.size()),
         builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
                                       dequantize_bias_outputs.size())));
+  } else if (INT8Weights()) {
+    operator_codes.emplace_back(
+        CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
+
+    std::vector<int8_t> filter_data(InputChannels() * OutputChannels());
+    std::vector<int8_t> bias_data(OutputChannels());
+
+    for (int32_t oc = 0; oc < OutputChannels(); oc++) {
+      // Use the same range of all-positive or all-negative values to generate
+      // all filter & bias weights within the same channel, but different ranges
+      // for different output channels. This ensures that no catastrophic
+      // cancellation occur, but test covers both positive and negative inputs.
+      const float range = range_rng();
+      auto value_rng =
+          std::bind(QuantizeInt8,
+                    std::bind(std::uniform_real_distribution<float>(
+                                  std::min(range, 0.0f), std::max(range, 0.0f)),
+                              std::ref(rng)),
+                    weights_zero_point_, weights_scale_);
+
+      bias_data[oc] = value_rng();
+      for (int32_t ic = 0; ic < InputChannels(); ic++) {
+        filter_data[oc * InputChannels() + ic] = value_rng();
+      }
+    }
+
+    buffers.emplace_back(CreateBuffer(
+        builder, builder.CreateVector(
+                     reinterpret_cast<const uint8_t*>(filter_data.data()),
+                     sizeof(int8_t) * filter_data.size())));
+    buffers.emplace_back(CreateBuffer(
+        builder,
+        builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
+                             sizeof(int8_t) * bias_data.size())));
+
+    const std::array<int32_t, 1> dequantize_filter_inputs{{0}};
+    const std::array<int32_t, 1> dequantize_filter_outputs{{3}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
+                                      dequantize_filter_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
+                                      dequantize_filter_outputs.size())));
+    const std::array<int32_t, 1> dequantize_bias_inputs{{1}};
+    const std::array<int32_t, 1> dequantize_bias_outputs{{4}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
+                                      dequantize_bias_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
+                                      dequantize_bias_outputs.size())));
   } else {
     std::vector<float> filter_data(InputChannels() * OutputChannels());
     std::vector<float> bias_data(OutputChannels());
@@ -221,6 +273,23 @@
         builder,
         builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
         TensorType_FLOAT16, /*buffer=*/2));
+  } else if (INT8Weights()){
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+        TensorType_INT8, /*buffer=*/1, /*name=*/0,
+        CreateQuantizationParameters(
+            builder, /*min=*/0, /*max=*/0,
+            builder.CreateVector<float>({weights_scale_}),
+            builder.CreateVector<int64_t>({weights_zero_point_}))));
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+        TensorType_INT8, /*buffer=*/2, /*name=*/0,
+        CreateQuantizationParameters(
+            builder, /*min=*/0, /*max=*/0,
+            builder.CreateVector<float>({weights_scale_}),
+            builder.CreateVector<int64_t>({weights_zero_point_}))));
   }
   tensors.emplace_back(CreateTensor(
       builder,
@@ -229,11 +298,11 @@
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
-      TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1));
+      TensorType_FLOAT32, /*buffer=*/(FP16Weights() || INT8Weights()) ? 0 : 1));
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
-      TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
+      TensorType_FLOAT32, /*buffer=*/(FP16Weights() || INT8Weights()) ? 0 : 2));
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h
index 6350bc8..d3e2fa5 100644
--- a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h
+++ b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h
@@ -78,6 +78,13 @@
 
   inline bool FP16Weights() const { return fp16_weights_; }
 
+  inline FullyConnectedTester& INT8Weights() {
+    int8_weights_ = true;
+    return *this;
+  }
+
+  inline bool INT8Weights() const { return int8_weights_; }
+
   inline FullyConnectedTester& ReluActivation() {
     activation_ = ::tflite::ActivationFunctionType_RELU;
     return *this;
@@ -110,6 +117,9 @@
   int32_t output_channels_ = 1;
   bool keep_dims_ = false;
   bool fp16_weights_ = false;
+  bool int8_weights_ = false;
+  int8_t weights_zero_point_ = 0;
+  float weights_scale_ = 0.75f;
   ::tflite::ActivationFunctionType activation_ =
       ::tflite::ActivationFunctionType_NONE;
 };
diff --git a/tensorflow/lite/delegates/xnnpack/maximum_test.cc b/tensorflow/lite/delegates/xnnpack/maximum_test.cc
index bad10f8..cc6f8c6 100644
--- a/tensorflow/lite/delegates/xnnpack/maximum_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/maximum_test.cc
@@ -708,6 +708,35 @@
       .Test(BuiltinOperator_MAXIMUM, xnnpack_delegate.get());
 }
 
+TEST(Maximum, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_MAXIMUM, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_MAXIMUM, xnnpack_delegate.get());
+}
+
 TEST(Maximum, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/minimum_test.cc b/tensorflow/lite/delegates/xnnpack/minimum_test.cc
index 0b564cf..2c44220 100644
--- a/tensorflow/lite/delegates/xnnpack/minimum_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/minimum_test.cc
@@ -708,6 +708,35 @@
       .Test(BuiltinOperator_MINIMUM, xnnpack_delegate.get());
 }
 
+TEST(Minimum, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_MINIMUM, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_MINIMUM, xnnpack_delegate.get());
+}
+
 TEST(Minimum, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/mul_test.cc b/tensorflow/lite/delegates/xnnpack/mul_test.cc
index 39d1643..4446aeb 100644
--- a/tensorflow/lite/delegates/xnnpack/mul_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/mul_test.cc
@@ -708,6 +708,35 @@
       .Test(BuiltinOperator_MUL, xnnpack_delegate.get());
 }
 
+TEST(Mul, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_MUL, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_MUL, xnnpack_delegate.get());
+}
+
 TEST(Mul, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/prelu_test.cc b/tensorflow/lite/delegates/xnnpack/prelu_test.cc
index 1002691..2e177b9 100644
--- a/tensorflow/lite/delegates/xnnpack/prelu_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/prelu_test.cc
@@ -535,6 +535,27 @@
       .Test(xnnpack_delegate.get());
 }
 
+TEST(Prelu, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  PreluTester()
+      .InputShape({batch, height, width, channels})
+      .SlopeShape({channels})
+      .INT8Weights()
+      .Test(xnnpack_delegate.get());
+}
+
 TEST(Prelu, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/prelu_tester.cc b/tensorflow/lite/delegates/xnnpack/prelu_tester.cc
index 4963424..bd4f498 100644
--- a/tensorflow/lite/delegates/xnnpack/prelu_tester.cc
+++ b/tensorflow/lite/delegates/xnnpack/prelu_tester.cc
@@ -25,6 +25,7 @@
 #include <gtest/gtest.h>
 #include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/delegates/xnnpack/test_util.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
@@ -106,7 +107,7 @@
   flatbuffers::FlatBufferBuilder builder;
   std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
       {CreateOperatorCode(builder, BuiltinOperator_PRELU)}};
-  if (FP16Weights()) {
+  if (FP16Weights() || INT8Weights()) {
     operator_codes.emplace_back(
         CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
   } else if (SparseWeights()) {
@@ -127,6 +128,15 @@
         builder, builder.CreateVector(
                      reinterpret_cast<const uint8_t*>(slope_data.data()),
                      sizeof(uint16_t) * slope_data.size())));
+  } else if (INT8Weights()) {
+    std::vector<int8_t> slope_data(ComputeSize(SlopeShape()));
+    std::generate(slope_data.begin(), slope_data.end(),
+                  std::bind(QuantizeInt8, slope_rng, slope_zero_point_, slope_scale_));
+
+    buffers.push_back(CreateBuffer(
+        builder, builder.CreateVector(
+                     reinterpret_cast<const uint8_t*>(slope_data.data()),
+                     sizeof(int8_t) * slope_data.size())));
   } else {
     std::vector<float> slope_data(ComputeSize(SlopeShape()));
     std::generate(slope_data.begin(), slope_data.end(), slope_rng);
@@ -144,6 +154,16 @@
         builder,
         builder.CreateVector<int32_t>(SlopeShape().data(), SlopeShape().size()),
         TensorType_FLOAT16, /*buffer=*/1));
+  } else if (INT8Weights()) {
+    tensors.emplace_back(
+        CreateTensor(builder,
+                     builder.CreateVector<int32_t>(SlopeShape().data(),
+                                                   SlopeShape().size()),
+                     TensorType_INT8, /*buffer=*/1, /*name=*/0,
+                     CreateQuantizationParameters(
+                             builder, /*min=*/0, /*max=*/0,
+                             builder.CreateVector<float>({slope_scale_}),
+                             builder.CreateVector<int64_t>({slope_zero_point_}))));
   } else if (SparseWeights()) {
     const int dims_count = SlopeShape().size();
     std::vector<flatbuffers::Offset<DimensionMetadata>> dim_metadata(
@@ -163,7 +183,7 @@
         TensorType_FLOAT32, /*buffer=*/1, /*name=*/0, /*quantization=*/0,
         /*is_variable=*/false, /*sparsity=*/sparsity_param));
   }
-  if (FP16Weights()) {
+  if (FP16Weights() || INT8Weights()) {
     const std::array<int32_t, 1> dequantize_inputs{{0}};
     const std::array<int32_t, 1> dequantize_outputs{{2}};
     operators.emplace_back(CreateOperator(
@@ -190,7 +210,7 @@
       builder,
       builder.CreateVector<int32_t>(SlopeShape().data(), SlopeShape().size()),
       TensorType_FLOAT32,
-      /*buffer=*/(FP16Weights() || SparseWeights()) ? 0 : 1));
+      /*buffer=*/(FP16Weights() || INT8Weights() || SparseWeights()) ? 0 : 1));
   tensors.emplace_back(CreateTensor(
       builder,
       builder.CreateVector<int32_t>(OutputShape().data(), OutputShape().size()),
diff --git a/tensorflow/lite/delegates/xnnpack/prelu_tester.h b/tensorflow/lite/delegates/xnnpack/prelu_tester.h
index e89bae6..32b74cc 100644
--- a/tensorflow/lite/delegates/xnnpack/prelu_tester.h
+++ b/tensorflow/lite/delegates/xnnpack/prelu_tester.h
@@ -62,6 +62,13 @@
 
   inline bool FP16Weights() const { return fp16_weights_; }
 
+  inline PreluTester& INT8Weights() {
+    int8_weights_ = true;
+    return *this;
+  }
+
+  inline bool INT8Weights() const { return int8_weights_; }
+
   inline PreluTester& SparseWeights() {
     sparse_weights_ = true;
     return *this;
@@ -79,7 +86,10 @@
   std::vector<int32_t> input_shape_;
   std::vector<int32_t> slope_shape_;
   bool fp16_weights_ = false;
+  bool int8_weights_ = false;
   bool sparse_weights_ = false;
+  int8_t slope_zero_point_ = 0;
+  float slope_scale_ = 0.75f;
 };
 
 }  // namespace xnnpack
diff --git a/tensorflow/lite/delegates/xnnpack/quantization_util.cc b/tensorflow/lite/delegates/xnnpack/quantization_util.cc
index 9217e69..466a9e0 100644
--- a/tensorflow/lite/delegates/xnnpack/quantization_util.cc
+++ b/tensorflow/lite/delegates/xnnpack/quantization_util.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -21,22 +21,10 @@
 #include <fp16.h>
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/types.h"
-#include "tensorflow/lite/kernels/internal/cppmath.h"
 
 namespace tflite {
 namespace xnnpack {
 
-int8_t QuantizeInt8(float value, int32_t zero_point, double scale) {
-  static constexpr int32_t min_val = std::numeric_limits<int8_t>::min();
-  static constexpr int32_t max_val = std::numeric_limits<int8_t>::max();
-
-  int32_t unclamped =
-      static_cast<int32_t>(TfLiteRound(value / static_cast<float>(scale))) +
-      zero_point;
-  int32_t clamped = std::min(std::max(unclamped, min_val), max_val);
-  return static_cast<int8_t>(clamped);
-}
-
 void DequantizeFloat16(const uint16_t *packed_fp16_data, float *unpacked_fp32_data,
                        size_t tensor_elements) {
   for (size_t i = 0; i < tensor_elements; ++i) {
diff --git a/tensorflow/lite/delegates/xnnpack/quantization_util.h b/tensorflow/lite/delegates/xnnpack/quantization_util.h
index bfa589b..86f2451 100644
--- a/tensorflow/lite/delegates/xnnpack/quantization_util.h
+++ b/tensorflow/lite/delegates/xnnpack/quantization_util.h
@@ -24,9 +24,6 @@
 namespace tflite {
 namespace xnnpack {
 
-// Only used for testing
-int8_t QuantizeInt8(float value, int32_t zero_point, double scale);
-
 void DequantizeInt8(const int8_t* packed_s8_data, float* unpacked_fp32_data,
                     const RuntimeShape& tensor_shape,
                     int32_t zero_point, double scale);
diff --git a/tensorflow/lite/delegates/xnnpack/squared_difference_test.cc b/tensorflow/lite/delegates/xnnpack/squared_difference_test.cc
index 75324c0..1848542 100644
--- a/tensorflow/lite/delegates/xnnpack/squared_difference_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/squared_difference_test.cc
@@ -708,6 +708,35 @@
       .Test(BuiltinOperator_SQUARED_DIFFERENCE, xnnpack_delegate.get());
 }
 
+TEST(SquaredDifference, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_SQUARED_DIFFERENCE, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_SQUARED_DIFFERENCE, xnnpack_delegate.get());
+}
+
 TEST(SquaredDifference, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/sub_test.cc b/tensorflow/lite/delegates/xnnpack/sub_test.cc
index e35deb4..ab4ae82 100644
--- a/tensorflow/lite/delegates/xnnpack/sub_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/sub_test.cc
@@ -708,6 +708,35 @@
       .Test(BuiltinOperator_SUB, xnnpack_delegate.get());
 }
 
+TEST(Sub, INT8Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_SUB, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .INT8Weights()
+      .Test(BuiltinOperator_SUB, xnnpack_delegate.get());
+}
+
 TEST(Sub, SparseWeights) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
diff --git a/tensorflow/lite/delegates/xnnpack/test_util.cc b/tensorflow/lite/delegates/xnnpack/test_util.cc
new file mode 100644
index 0000000..4000ed3
--- /dev/null
+++ b/tensorflow/lite/delegates/xnnpack/test_util.cc
@@ -0,0 +1,38 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/xnnpack/test_util.h"
+
+#include <algorithm>
+#include <limits>
+
+#include "tensorflow/lite/kernels/internal/cppmath.h"
+
+namespace tflite {
+namespace xnnpack {
+
+int8_t QuantizeInt8(float value, int32_t zero_point, double scale) {
+  static constexpr int32_t min_val = std::numeric_limits<int8_t>::min();
+  static constexpr int32_t max_val = std::numeric_limits<int8_t>::max();
+
+  int32_t unclamped =
+      static_cast<int32_t>(TfLiteRound(value / static_cast<float>(scale))) +
+      zero_point;
+  int32_t clamped = std::min(std::max(unclamped, min_val), max_val);
+  return static_cast<int8_t>(clamped);
+}
+
+}  // namespace xnnpack
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/xnnpack/test_util.h b/tensorflow/lite/delegates/xnnpack/test_util.h
new file mode 100644
index 0000000..a604c72
--- /dev/null
+++ b/tensorflow/lite/delegates/xnnpack/test_util.h
@@ -0,0 +1,29 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_TEST_UTIL_H_
+#define TENSORFLOW_LITE_DELEGATES_XNNPACK_TEST_UTIL_H_
+
+#include <cstdint>
+
+namespace tflite {
+namespace xnnpack {
+
+int8_t QuantizeInt8(float value, int32_t zero_point, double scale);
+
+}  // namespace xnnpack
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_TEST_UTIL_H_