Enable affine quantized tensor in writer_lib.
PiperOrigin-RevId: 281626472
Change-Id: Ic922833550599b54c966ce65ad76a712f298fd89
diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc
index de75e14..6f9507e 100644
--- a/tensorflow/lite/experimental/writer/writer_lib.cc
+++ b/tensorflow/lite/experimental/writer/writer_lib.cc
@@ -164,19 +164,38 @@
// Primitive type.
TensorType type = TfLiteTypeToSchemaType(tensor->type);
// Handle quantization
+ flatbuffers::Offset<QuantizationParameters> quantization_params;
+
const flatbuffers::Offset<flatbuffers::Vector<float>> null_array;
flatbuffers::Offset<flatbuffers::Vector<float>> scale_array;
flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point_array;
- if (tensor->params.scale != 0.f) {
- // We have quantization, make a single arugment array (multi channel
- // quant needs updating here).
- scale_array = fbb->CreateVector<float>({tensor->params.scale});
- zero_point_array =
- fbb->CreateVector<int64_t>({tensor->params.zero_point});
+  // Multi-channel (per-axis) quantization.
+ if (tensor->quantization.type == kTfLiteAffineQuantization) {
+ const TfLiteAffineQuantization* params =
+ reinterpret_cast<TfLiteAffineQuantization*>(
+ tensor->quantization.params);
+ const size_t num_scales = params->scale->size;
+
+ const int channel_index = params->quantized_dimension;
+ std::vector<float> scale_vector(
+ {params->scale->data, params->scale->data + num_scales});
+ std::vector<int64_t> zero_point_vector(
+ {params->zero_point->data, params->zero_point->data + num_scales});
+ scale_array = fbb->CreateVector<float>(scale_vector);
+ zero_point_array = fbb->CreateVector<int64_t>(zero_point_vector);
+ quantization_params = CreateQuantizationParameters(
+ *fbb, null_array, null_array, scale_array, zero_point_array,
+ QuantizationDetails_NONE, 0, channel_index);
+ } else {
+    // Per-tensor quantization: a single-element scale/zero-point array.
+ if (tensor->params.scale != 0.f) {
+ scale_array = fbb->CreateVector<float>({tensor->params.scale});
+ zero_point_array =
+ fbb->CreateVector<int64_t>({tensor->params.zero_point});
+ }
+ quantization_params = CreateQuantizationParameters(
+ *fbb, null_array, null_array, scale_array, zero_point_array);
}
- flatbuffers::Offset<QuantizationParameters> quantization_params =
- CreateQuantizationParameters(*fbb, null_array, null_array,
- scale_array, zero_point_array);
// Shape
TfLiteIntArrayView shape_view(tensor->dims);
std::vector<int> shape =