Add option to revert to previous hybrid quantization scheme
PiperOrigin-RevId: 322886366
Change-Id: I7665938cfecd27e1be09ccfdd145ec02eb03d675
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index 8bef019..e4840ae 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -130,7 +130,8 @@
// Returns true if the operator supports hybrid evaluation.
bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
- const CustomOpMap& custom_op_map) {
+ const CustomOpMap& custom_op_map,
+ bool use_updated_hybrid_scheme) {
const BuiltinOperator builtin_op_code = op_code->builtin_code;
// Operations that support hybrid evaluation.
bool eval_hybrid = false;
@@ -144,7 +145,6 @@
}
} else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
builtin_op_code == BuiltinOperator_CONV_2D ||
- builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
builtin_op_code == BuiltinOperator_SVDF ||
builtin_op_code == BuiltinOperator_RNN ||
builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
@@ -158,6 +158,8 @@
if (options->kernel_type == LSTMKernelType_FULL) {
eval_hybrid = true;
}
+ } else if (builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
+ eval_hybrid = use_updated_hybrid_scheme;
}
return eval_hybrid;
}
@@ -191,7 +193,7 @@
const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map,
absl::flat_hash_map<int32_t, TensorPerChannel>* tensor_map,
- int subgraph_index) {
+ int subgraph_index, bool use_updated_hybrid_scheme) {
SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get();
const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
@@ -231,43 +233,46 @@
}
if (op_code->builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
- tensor_map->insert(
- {tensor_idx, {tensor, /*is_per_channel=*/true, /*dim=*/3}});
+ tensor_map->insert({tensor_idx,
+ {tensor, /*is_per_channel=*/use_updated_hybrid_scheme,
+ /*dim=*/3}});
} else if (op_code->builtin_code == BuiltinOperator_CONV_2D) {
- tensor_map->insert(
- {tensor_idx, {tensor, /*is_per_channel=*/true, /*dim=*/0}});
+ tensor_map->insert({tensor_idx,
+ {tensor, /*is_per_channel=*/use_updated_hybrid_scheme,
+ /*dim=*/0}});
} else {
switch (op_code->builtin_code) {
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
op->builtin_options.AsBidirectionalSequenceLSTMOptions()
- ->asymmetric_quantize_inputs = true;
+ ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
break;
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
op->builtin_options.AsBidirectionalSequenceRNNOptions()
- ->asymmetric_quantize_inputs = true;
+ ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
break;
case BuiltinOperator_FULLY_CONNECTED:
op->builtin_options.AsFullyConnectedOptions()
- ->asymmetric_quantize_inputs = true;
+ ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
break;
case BuiltinOperator_LSTM:
op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
- true;
+ use_updated_hybrid_scheme;
break;
case BuiltinOperator_RNN:
- op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs = true;
+ op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs =
+ use_updated_hybrid_scheme;
break;
case BuiltinOperator_SVDF:
op->builtin_options.AsSVDFOptions()->asymmetric_quantize_inputs =
- true;
+ use_updated_hybrid_scheme;
break;
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
op->builtin_options.AsUnidirectionalSequenceLSTMOptions()
- ->asymmetric_quantize_inputs = true;
+ ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
break;
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
op->builtin_options.AsSequenceRNNOptions()
- ->asymmetric_quantize_inputs = true;
+ ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
break;
default:
break;
@@ -323,25 +328,27 @@
}
// Updates operator code versions for the operators with INT8 inputs.
-void UpdateInt8OperatorVersions(ModelT* model) {
+void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) {
for (int i = 0; i < model->operator_codes.size(); ++i) {
const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
- if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
+ if (op_code == BuiltinOperator_RNN ||
op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
- op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
- op_code == BuiltinOperator_RNN ||
op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
+ model->operator_codes[i]->version = use_updated_hybrid_scheme ? 3 : 2;
+ } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
+ op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
model->operator_codes[i]->version = 3;
- } else if (op_code == BuiltinOperator_LSTM ||
- op_code == BuiltinOperator_SVDF) {
- model->operator_codes[i]->version = 4;
+ } else if (op_code == BuiltinOperator_LSTM) {
+ model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 3;
} else if (op_code == BuiltinOperator_CONV_2D) {
- model->operator_codes[i]->version = 5;
+ model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2;
+ } else if (op_code == BuiltinOperator_FULLY_CONNECTED) {
+ model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3;
+ } else if (op_code == BuiltinOperator_SVDF) {
+ model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2;
} else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
model->operator_codes[i]->version = 6;
- } else if (op_code == BuiltinOperator_FULLY_CONNECTED) {
- model->operator_codes[i]->version = 9;
}
}
}
@@ -402,7 +409,8 @@
const Model* input_model,
bool use_hybrid_evaluation,
uint64_t weights_min_num_elements,
- const CustomOpMap& custom_op_map) {
+ const CustomOpMap& custom_op_map,
+ bool use_updated_hybrid_scheme) {
std::unique_ptr<ModelT> model;
model.reset(input_model->UnPack());
@@ -415,7 +423,7 @@
OperatorT* op = subgraph->operators[i].get();
TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map,
- subgraph_index));
+ subgraph_index, use_updated_hybrid_scheme));
}
for (std::pair<int32_t, TensorPerChannel> tensor_pair : tensor_map) {
@@ -456,8 +464,8 @@
// dequantization we need to add a Dequantize op.
bool eval_hybrid =
use_hybrid_evaluation &&
- IsHybridEvaluationOp(consumer_op, consumer_op_code,
- custom_op_map) &&
+ IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map,
+ use_updated_hybrid_scheme) &&
CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code,
custom_op_map) &&
IsQuantizedInput(consumer_op_code, custom_op_map,
@@ -516,7 +524,7 @@
}
// Update the modified operator code versions.
- UpdateInt8OperatorVersions(model.get());
+ UpdateInt8OperatorVersions(model.get(), use_updated_hybrid_scheme);
flatbuffers::Offset<Model> output_model_location =
Model::Pack(*builder, model.get());
@@ -611,7 +619,8 @@
// kWeightsMinSizeDefault elements are quantized.
CustomOpMap custom_op_map;
return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation,
- weights_min_num_elements, custom_op_map);
+ weights_min_num_elements, custom_op_map,
+ kUseUpdatedHybridSchemeDefault);
}
} // namespace internal
@@ -620,7 +629,8 @@
uint64_t weights_min_num_elements) {
CustomOpMap custom_op_map;
return QuantizeWeightsInt8(builder, input_model, true,
- weights_min_num_elements, custom_op_map);
+ weights_min_num_elements, custom_op_map,
+ kUseUpdatedHybridSchemeDefault);
}
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
@@ -631,7 +641,8 @@
// kWeightsMinSizeDefault elements are quantized.
CustomOpMap custom_op_map;
return QuantizeWeightsInt8(builder, input_model, true,
- kWeightsMinNumElementsDefault, custom_op_map);
+ kWeightsMinNumElementsDefault, custom_op_map,
+ kUseUpdatedHybridSchemeDefault);
}
case BufferType::QUANTIZED_FLOAT16:
return QuantizeWeightsFloat16(builder, input_model);
@@ -643,7 +654,19 @@
uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map) {
return QuantizeWeightsInt8(builder, input_model, true,
- weights_min_num_elements, custom_op_map);
+ weights_min_num_elements, custom_op_map,
+ kUseUpdatedHybridSchemeDefault);
+}
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ uint64_t weights_min_num_elements,
+ const CustomOpMap& custom_op_map,
+ bool use_updated_hybrid_scheme) {
+ return QuantizeWeightsInt8(builder, input_model,
+ /*use_hybrid_evaluation=*/true,
+ weights_min_num_elements, custom_op_map,
+ use_updated_hybrid_scheme);
}
} // namespace optimize
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.h b/tensorflow/lite/tools/optimize/quantize_weights.h
index 528614f..9212c9a 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.h
+++ b/tensorflow/lite/tools/optimize/quantize_weights.h
@@ -29,6 +29,13 @@
// Supported resulting types from quantization process.
enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 };
+// This constant is for internal use by conversions requiring the previous behavior.
+#ifdef TFLITE_USE_PREVIOUS_HYBRID_SCHEME
+constexpr bool kUseUpdatedHybridSchemeDefault = false;
+#else
+constexpr bool kUseUpdatedHybridSchemeDefault = true;
+#endif
+
// Quantizes input_model and populates the provided builder with the new model.
// By default only weights tensors weight more than 1024 elements will be
// quantized.
@@ -61,6 +68,14 @@
uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map);
+// Same as above, but if use_updated_hybrid_scheme is false,
+// uses the previous quantization scheme.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ uint64_t weights_min_num_elements,
+ const CustomOpMap& custom_op_map,
+ bool use_updated_hybrid_scheme);
+
namespace internal {
// If use_hybrid_evaluation is false, will disable using hybrid eval for
// operations that support it.
diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc
index 2f92a9a..94bff2d 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc
@@ -216,7 +216,11 @@
EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
<< quant_tensor->name()->str();
auto shape = GetAsVector(quant_tensor->shape());
- EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]);
+ if (kUseUpdatedHybridSchemeDefault) {
+ EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]);
+ } else {
+ EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
+ }
} else {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
}
@@ -533,6 +537,58 @@
EXPECT_EQ(num_custom_ops_found, 1);
}
+TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) {
+ LoadBasicModel();
+ flatbuffers::FlatBufferBuilder builder;
+ const CustomOpMap custom_op_map;
+ auto status = QuantizeWeights(&builder, model_, 0, custom_op_map, false);
+ EXPECT_EQ(status, kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+ ASSERT_TRUE(output_model);
+
+ // Nothing should change.
+ ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
+ for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
+ subgraph_idx++) {
+ const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
+ const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
+ ASSERT_EQ(quantized_graph->tensors()->size(),
+ float_graph->tensors()->size());
+ // Make sure the graph only has one Conv operation.
+ ASSERT_EQ(quantized_graph->operators()->size(), 1);
+ const auto op = quantized_graph->operators()->Get(0);
+ const uint32_t op_code_idx = op->opcode_index();
+ ASSERT_EQ(output_model->operator_codes()->Get(op_code_idx)->builtin_code(),
+ BuiltinOperator_CONV_2D);
+ for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) {
+ const auto quant_tensor = quantized_graph->tensors()->Get(i);
+ const auto float_tensor = float_graph->tensors()->Get(i);
+ EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer());
+ EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable());
+ EXPECT_EQ(GetAsVector(quant_tensor->shape()),
+ GetAsVector(float_tensor->shape()));
+ EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str());
+ // If the tensor is a weight, it should have type INT8, otherwise it
+ // should stay with type FLOAT32.
+ // If the tensor is a bias, it should have type FLOAT32.
+ if (quant_tensor->name()->str() == "conv_bias") {
+ EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
+ } else if (IsModelInputOrOutput(output_model, i)) {
+ EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
+ } else if (quant_tensor->buffer() != 0) {
+ EXPECT_EQ(quant_tensor->type(), TensorType_INT8)
+ << quant_tensor->name()->str();
+ auto shape = GetAsVector(quant_tensor->shape());
+ EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1);
+ } else {
+ EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
+ }
+ }
+ }
+}
+
} // namespace
} // namespace optimize
} // namespace tflite