Enable affine quantized tensor in writer_lib.

PiperOrigin-RevId: 281626472
Change-Id: Ic922833550599b54c966ce65ad76a712f298fd89
diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc
index de75e14..6f9507e 100644
--- a/tensorflow/lite/experimental/writer/writer_lib.cc
+++ b/tensorflow/lite/experimental/writer/writer_lib.cc
@@ -164,19 +164,38 @@
       // Primitive type.
       TensorType type = TfLiteTypeToSchemaType(tensor->type);
       // Handle quantization
+      flatbuffers::Offset<QuantizationParameters> quantization_params;
+
       const flatbuffers::Offset<flatbuffers::Vector<float>> null_array;
       flatbuffers::Offset<flatbuffers::Vector<float>> scale_array;
       flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point_array;
-      if (tensor->params.scale != 0.f) {
-        // We have quantization, make a single arugment array (multi channel
-        // quant needs updating here).
-        scale_array = fbb->CreateVector<float>({tensor->params.scale});
-        zero_point_array =
-            fbb->CreateVector<int64_t>({tensor->params.zero_point});
+      // Multi channel quantization.
+      if (tensor->quantization.type == kTfLiteAffineQuantization) {
+        const TfLiteAffineQuantization* params =
+            reinterpret_cast<TfLiteAffineQuantization*>(
+                tensor->quantization.params);
+        const size_t num_scales = params->scale->size;
+
+        const int channel_index = params->quantized_dimension;
+        std::vector<float> scale_vector(
+            {params->scale->data, params->scale->data + num_scales});
+        std::vector<int64_t> zero_point_vector(
+            {params->zero_point->data, params->zero_point->data + num_scales});
+        scale_array = fbb->CreateVector<float>(scale_vector);
+        zero_point_array = fbb->CreateVector<int64_t>(zero_point_vector);
+        quantization_params = CreateQuantizationParameters(
+            *fbb, null_array, null_array, scale_array, zero_point_array,
+            QuantizationDetails_NONE, 0, channel_index);
+      } else {
+        // Quantization with a single argument array.
+        if (tensor->params.scale != 0.f) {
+          scale_array = fbb->CreateVector<float>({tensor->params.scale});
+          zero_point_array =
+              fbb->CreateVector<int64_t>({tensor->params.zero_point});
+        }
+        quantization_params = CreateQuantizationParameters(
+            *fbb, null_array, null_array, scale_array, zero_point_array);
       }
-      flatbuffers::Offset<QuantizationParameters> quantization_params =
-          CreateQuantizationParameters(*fbb, null_array, null_array,
-                                       scale_array, zero_point_array);
       // Shape
       TfLiteIntArrayView shape_view(tensor->dims);
       std::vector<int> shape =