Refactor operator versioning
- Introduced the tools/versioning package to hold the operator versioning logic.
- Let the SimpleOperator constructor accept the builtin op enum.
- The test cases in op_version_test.cc are written with reference to operator_test.cc.
PiperOrigin-RevId: 269512025
diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD
index 4fff36f..43cb88d 100644
--- a/tensorflow/lite/toco/tflite/BUILD
+++ b/tensorflow/lite/toco/tflite/BUILD
@@ -32,6 +32,7 @@
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/toco:graph_transformations",
"//tensorflow/lite/toco:model",
+ "//tensorflow/lite/tools/versioning:op_version",
"@com_google_absl//absl/memory",
"@flatbuffers",
],
diff --git a/tensorflow/lite/toco/tflite/builtin_operator.h b/tensorflow/lite/toco/tflite/builtin_operator.h
index ea012ff..070fab4 100644
--- a/tensorflow/lite/toco/tflite/builtin_operator.h
+++ b/tensorflow/lite/toco/tflite/builtin_operator.h
@@ -36,7 +36,8 @@
using TfLiteOptions = T2;
BuiltinOperator(::tflite::BuiltinOperator op, OperatorType type)
- : BaseOperator(::tflite::EnumNameBuiltinOperator(op), type) {}
+ : BaseOperator(::tflite::EnumNameBuiltinOperator(op), type),
+ builtin_op_(op) {}
// Build the configuration object in the given flatbuffer builder. Return
// its offset.
@@ -65,6 +66,16 @@
}
return std::unique_ptr<Operator>(op.release());
}
+
+ int GetVersion(const OperatorSignature& op_signature) const override {
+ return ::tflite::GetBuiltinOperatorVersion(
+ GetVersioningOpSig(builtin_op_, op_signature));
+ }
+
+ ::tflite::BuiltinOperator builtin_op() const { return builtin_op_; }
+
+ private:
+ const ::tflite::BuiltinOperator builtin_op_;
};
} // namespace tflite
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index d067199..1dc653e 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -14,6 +14,8 @@
==============================================================================*/
#include "tensorflow/lite/toco/tflite/operator.h"
+#include <map>
+
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -30,6 +32,7 @@
#include "tensorflow/lite/toco/tflite/custom_operator.h"
#include "tensorflow/lite/toco/tflite/simple_operator.h"
#include "tensorflow/lite/toco/tflite/types.h"
+#include "tensorflow/lite/tools/versioning/op_version.h"
namespace toco {
@@ -37,6 +40,49 @@
// LINT.IfChange
+::tflite::TensorType GetTensorType(const ArrayDataType type) {
+ const std::map<ArrayDataType, ::tflite::TensorType> tensor_type_map = {
+ {ArrayDataType::kBool, ::tflite::TensorType_BOOL},
+ {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
+ {ArrayDataType::kInt8, ::tflite::TensorType_INT8},
+ {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
+ {ArrayDataType::kInt16, ::tflite::TensorType_INT16},
+ {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
+ {ArrayDataType::kInt64, ::tflite::TensorType_INT64},
+ {ArrayDataType::kString, ::tflite::TensorType_STRING},
+ {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
+ {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16}};
+
+ auto it = tensor_type_map.find(type);
+ if (it != tensor_type_map.end()) {
+ return it->second;
+ }
+ return static_cast<::tflite::TensorType>(-1);
+}
+
+::tflite::OpSignature GetVersioningOpSig(
+ const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
+ std::vector<::tflite::TensorType> input_types, output_types;
+ for (auto input_name : op_signature.op->inputs) {
+ ::tflite::TensorType input_type = static_cast<::tflite::TensorType>(-1);
+ if (op_signature.model->HasArray(input_name)) {
+ const Array& input_array = op_signature.model->GetArray(input_name);
+ input_type = GetTensorType(input_array.data_type);
+ }
+ input_types.push_back(input_type);
+ }
+ for (auto output_name : op_signature.op->outputs) {
+ ::tflite::TensorType output_type = static_cast<::tflite::TensorType>(-1);
+ if (op_signature.model->HasArray(output_name)) {
+ const Array& output_array = op_signature.model->GetArray(output_name);
+ output_type = GetTensorType(output_array.data_type);
+ }
+ output_types.push_back(output_type);
+ }
+ return ::tflite::OpSignature{
+ .op = op, .input_types = input_types, .output_types = output_types};
+}
+
class AveragePool
: public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
::tflite::BuiltinOptions_Pool2DOptions> {
@@ -64,15 +110,6 @@
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Convolution
@@ -103,29 +140,6 @@
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const string& filter_name = op_signature.op->inputs[1];
- const string& output_name = op_signature.op->outputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- const Array& filter_array = op_signature.model->GetArray(filter_name);
- const Array& output_array = op_signature.model->GetArray(output_name);
- // If the op has signed int8 inputs and outputs, its version 3.
- if (input_array.data_type == ArrayDataType::kInt8 &&
- filter_array.data_type == ArrayDataType::kInt8 &&
- output_array.data_type == ArrayDataType::kInt8) {
- return 3;
- }
- // If the op is a signed int8 hybrid operation, we need to return
- // version 2.
- if (input_array.data_type == ArrayDataType::kFloat &&
- filter_array.data_type == ArrayDataType::kInt8 &&
- output_array.data_type == ArrayDataType::kFloat) {
- return 2;
- }
- return 1;
- }
};
class DepthwiseConvolution
@@ -162,23 +176,13 @@
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& conv_op =
static_cast<const DepthwiseConvOperator&>(*op_signature.op);
- const string& input_name = op_signature.op->inputs[0];
- const string& filter_name = op_signature.op->inputs[1];
- const string& output_name = op_signature.op->outputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- const Array& filter_array = op_signature.model->GetArray(filter_name);
- const Array& output_array = op_signature.model->GetArray(output_name);
- // If the op has signed int8 inputs and outputs, its version 3.
- if (input_array.data_type == ArrayDataType::kInt8 &&
- filter_array.data_type == ArrayDataType::kInt8 &&
- output_array.data_type == ArrayDataType::kInt8) {
- return 3;
- }
- if (conv_op.dilation_width_factor != 1 ||
- conv_op.dilation_height_factor != 1) {
- return 2;
- }
- return 1;
+ ::tflite::OpSignature op_sig =
+ GetVersioningOpSig(builtin_op(), op_signature);
+ op_sig.options.depthwise_conv_2d.dilation_w_factor =
+ conv_op.dilation_width_factor;
+ op_sig.options.depthwise_conv_2d.dilation_h_factor =
+ conv_op.dilation_height_factor;
+ return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};
@@ -200,16 +204,6 @@
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
@@ -225,10 +219,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class SpaceToBatchND
@@ -246,16 +236,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
@@ -276,16 +256,6 @@
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
@@ -306,10 +276,6 @@
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class BatchToSpaceND
@@ -327,16 +293,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
@@ -356,10 +312,6 @@
op->src_data_type = DataType::Deserialize(options.in_data_type());
op->dst_data_type = DataType::Deserialize(options.out_data_type());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class Concatenation
@@ -378,16 +330,6 @@
TocoOperator* op) const override {
op->axis = options.axis();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class DepthToSpace
@@ -406,10 +348,6 @@
TocoOperator* op) const override {
op->block_size = options.block_size();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class FakeQuant
@@ -434,7 +372,10 @@
}
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
- return fq_op.narrow_range ? 2 : 1;
+ ::tflite::OpSignature op_sig =
+ GetVersioningOpSig(builtin_op(), op_signature);
+ op_sig.options.fakequant.narrow_range = fq_op.narrow_range;
+ return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};
@@ -444,28 +385,27 @@
::tflite::BuiltinOptions_FullyConnectedOptions> {
public:
using BuiltinOperator::BuiltinOperator;
+
+ ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
+ FullyConnectedWeightsFormat fmt) const {
+ switch (fmt) {
+ case FullyConnectedWeightsFormat::kDefault:
+ return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
+ case FullyConnectedWeightsFormat::kShuffled4x16Int8:
+ return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ default:
+ LOG(ERROR) << "Unhandled FC weights format";
+ return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
+ }
+ }
+
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
auto activation_function =
ActivationFunction::Serialize(op.fused_activation_function);
- ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format;
- switch (op.weights_format) {
- case FullyConnectedWeightsFormat::kDefault:
- tflite_weights_format =
- ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
- break;
- case FullyConnectedWeightsFormat::kShuffled4x16Int8:
- tflite_weights_format =
- ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
- break;
- default:
- LOG(ERROR) << "Unhandled FC weights format";
- tflite_weights_format =
- ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
- }
- return ::tflite::CreateFullyConnectedOptions(*builder, activation_function,
- tflite_weights_format);
+ return ::tflite::CreateFullyConnectedOptions(
+ *builder, activation_function, GetWeightFormat(op.weights_format));
}
void ReadOptions(const TfLiteOptions& options,
@@ -485,53 +425,15 @@
}
}
- // +-----------------+--------------------+--------------------------+
- // | | Weight::Default | Weight::Shuffled4x16Int8 |
- // +-----------------+--------------------+--------------------------+
- // | Float | 1 | 2 |
- // | Quantized Uint8 | 1 | 2 |
- // | Hybrid | 3 | 3 |
- // | Quantized Int8 | 4 | 4 |
- // +-----------------+--------------------+--------------------------+
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& fc_op =
static_cast<const FullyConnectedOperator&>(*op_signature.op);
- const string& input_name = op_signature.op->inputs[0];
- const string& weights_name = op_signature.op->inputs[1];
- const string& output_name = op_signature.op->outputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- const Array& weights_array = op_signature.model->GetArray(weights_name);
- const Array& output_array = op_signature.model->GetArray(output_name);
- // 2 inputs (no bias) use case is supported starting from version 6.
- if (op_signature.op->inputs.size() == 2) {
- return 6;
- }
- // `keep_num_dims` is supported at verison 5.
- if (fc_op.keep_num_dims) {
- return 5;
- }
- // Int8 fully fixed point kernel is at version 4.
- if (input_array.data_type == ArrayDataType::kInt8 &&
- weights_array.data_type == ArrayDataType::kInt8 &&
- output_array.data_type == ArrayDataType::kInt8) {
- return 4;
- }
- // If the op is a signed int8 hybrid operation, we need to return
- // version 3.
- if (input_array.data_type == ArrayDataType::kFloat &&
- weights_array.data_type == ArrayDataType::kInt8 &&
- output_array.data_type == ArrayDataType::kFloat) {
- return 3;
- }
- // For float and uint8 fixed point kernels, if the weight is
- // Shuffled4x16Int8, is is version 2.
- if (fc_op.weights_format ==
- FullyConnectedWeightsFormat::kShuffled4x16Int8) {
- return 2;
- }
-
- // Otherwise (weight is default), the version is 1.
- return 1;
+ ::tflite::OpSignature op_sig =
+ GetVersioningOpSig(builtin_op(), op_signature);
+ op_sig.options.fully_connected.keep_num_dims = fc_op.keep_num_dims;
+ op_sig.options.fully_connected.weights_format =
+ GetWeightFormat(fc_op.weights_format);
+ return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};
@@ -550,20 +452,6 @@
TocoOperator* op) const override {
op->axis = {options.axis()};
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op takes bool input, it is version 3.
- if (input_array.data_type == ArrayDataType::kBool) {
- return 3;
- }
- // If the op takes int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class GatherNd
@@ -580,10 +468,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
@@ -604,24 +488,6 @@
ActivationFunction::Deserialize(options.fused_activation_function());
op->rank = options.rank();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const string& weights_feature_name = op_signature.op->inputs[1];
- const string& output_name = op_signature.op->outputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- const Array& weights_feature_array =
- op_signature.model->GetArray(weights_feature_name);
- const Array& output_array = op_signature.model->GetArray(output_name);
- // If the op is a signed int8 hybrid operation, we need to return
- // version 2.
- if (input_array.data_type == ArrayDataType::kFloat &&
- weights_feature_array.data_type == ArrayDataType::kInt8 &&
- output_array.data_type == ArrayDataType::kFloat) {
- return 2;
- }
- return 1;
- }
};
class L2Normalization
@@ -642,16 +508,6 @@
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& output_name = op_signature.op->outputs[0];
- const Array& output_array = op_signature.model->GetArray(output_name);
- // Version 2 supports signed int8 input types.
- if (output_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
@@ -679,10 +535,6 @@
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class LocalResponseNormalization
@@ -706,10 +558,6 @@
op->alpha = options.alpha();
op->beta = options.beta();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
@@ -737,43 +585,6 @@
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class Maximum : public SimpleOperator<TensorFlowMaximumOperator> {
- public:
- explicit Maximum() : SimpleOperator("MAXIMUM", OperatorType::kMaximum) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class Minimum : public SimpleOperator<TensorFlowMinimumOperator> {
- public:
- explicit Minimum() : SimpleOperator("MINIMUM", OperatorType::kMinimum) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
@@ -805,17 +616,15 @@
const auto& input1_quant = input1_array.quantization_params;
const auto& input2_quant = input2_array.quantization_params;
const auto& output_quant = output_array.quantization_params;
- // Version 3 supports have a rescale value greater than or equal to 1.
- if (input1_quant && input2_quant && output_quant &&
- (input1_quant->scale * input2_quant->scale / output_quant->scale) >=
- 1.0) {
- return 3;
- }
- // Version 2 supports signed int8 input types.
- if (input1_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
+ const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
+ const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
+ const float output_scale = output_quant ? output_quant->scale : 0.0f;
+ ::tflite::OpSignature op_sig =
+ GetVersioningOpSig(builtin_op(), op_signature);
+ op_sig.options.mul.input1_scale = input1_scale;
+ op_sig.options.mul.input2_scale = input2_scale;
+ op_sig.options.mul.output_scale = output_scale;
+ return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};
@@ -832,16 +641,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Tile
@@ -857,9 +656,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
@@ -875,16 +671,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Reshape
@@ -906,10 +692,6 @@
op->shape.insert(op->shape.end(), options.new_shape()->begin(),
options.new_shape()->end());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class Softmax
@@ -927,15 +709,6 @@
TocoOperator* op) const override {
op->beta = options.beta();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class SpaceToDepth
@@ -954,16 +727,6 @@
TocoOperator* op) const override {
op->block_size = options.block_size();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Transpose
@@ -979,40 +742,32 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op takes bool input, it is version 3.
- if (input_array.data_type == ArrayDataType::kBool) {
- return 3;
- }
- // If the op takes int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
::tflite::BuiltinOptions_LSTMOptions> {
public:
using BuiltinOperator::BuiltinOperator;
+
+ ::tflite::LSTMKernelType GetKernelType(
+ LstmCellOperator::KernelType type) const {
+ switch (type) {
+ case LstmCellOperator::KERNEL_BASIC:
+ return ::tflite::LSTMKernelType_BASIC;
+ break;
+ case LstmCellOperator::KERNEL_FULL:
+ return ::tflite::LSTMKernelType_FULL;
+ break;
+ default:
+ LOG(ERROR) << "Unhandled Kernel Type";
+ return static_cast<::tflite::LSTMKernelType>(-1);
+ }
+ }
+
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
- ::tflite::LSTMKernelType kernel_type = ::tflite::LSTMKernelType_FULL;
- switch (op.kernel_type) {
- case LstmCellOperator::KERNEL_BASIC:
- kernel_type = ::tflite::LSTMKernelType_BASIC;
- break;
- case LstmCellOperator::KERNEL_FULL:
- kernel_type = ::tflite::LSTMKernelType_FULL;
- break;
- default:
- return -1;
- }
+ ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);
// Current toco converter only supports tanh, no clip.
return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
@@ -1040,27 +795,10 @@
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& lstm_op =
static_cast<const LstmCellOperator&>(*op_signature.op);
- switch (lstm_op.kernel_type) {
- case LstmCellOperator::KERNEL_FULL: {
- // If the input tensor is float and a weight is int8, this is a version
- // 3 hybrid operation.
- const string& input_name = op_signature.op->inputs[0];
- const string& weights_name = op_signature.op->inputs[2];
- const string& output_name = op_signature.op->outputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- const Array& weights_array = op_signature.model->GetArray(weights_name);
- const Array& output_array = op_signature.model->GetArray(output_name);
- if (input_array.data_type == ArrayDataType::kFloat &&
- weights_array.data_type == ArrayDataType::kInt8 &&
- output_array.data_type == ArrayDataType::kFloat) {
- return 3;
- }
- return 1;
- }
- case LstmCellOperator::KERNEL_BASIC:
- // KERNEL_BASIC was added in version 2.
- return 2;
- }
+ ::tflite::OpSignature op_sig =
+ GetVersioningOpSig(builtin_op(), op_signature);
+ op_sig.options.lstm.kernel_type = GetKernelType(lstm_op.kernel_type);
+ return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
std::vector<bool> GetMutatingInputVariables(
@@ -1110,23 +848,6 @@
::tflite::ActivationFunctionType_TANH);
}
- int GetVersion(const OperatorSignature& op_signature) const override {
- // If the input tensor is float and a weight is int8, this is a version
- // 2 hybrid operation.
- const string& input_name = op_signature.op->inputs[0];
- const string& weights_name = op_signature.op->inputs[2];
- const string& output_name = op_signature.op->outputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- const Array& weights_array = op_signature.model->GetArray(weights_name);
- const Array& output_array = op_signature.model->GetArray(output_name);
- if (input_array.data_type == ArrayDataType::kFloat &&
- weights_array.data_type == ArrayDataType::kInt8 &&
- output_array.data_type == ArrayDataType::kFloat) {
- return 2;
- }
- return 1;
- }
-
std::vector<bool> GetMutatingInputVariables(
const Operator& op) const override {
std::vector<bool> mutating_input_variables(op.inputs.size(), false);
@@ -1164,10 +885,6 @@
op->merge_outputs = options.merge_outputs();
}
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
-
std::vector<bool> GetMutatingInputVariables(
const Operator& op) const override {
std::vector<bool> mutating_input_variables(op.inputs.size(), false);
@@ -1209,10 +926,6 @@
op->merge_outputs = options.merge_outputs();
}
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
-
std::vector<bool> GetMutatingInputVariables(
const Operator& op) const override {
std::vector<bool> mutating_input_variables(op.inputs.size(), false);
@@ -1238,16 +951,6 @@
TocoOperator* op) const override {
op->keep_dims = options.keep_dims();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Sum
@@ -1265,15 +968,6 @@
TocoOperator* op) const override {
op->keep_dims = options.keep_dims();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class ReduceMax
@@ -1291,16 +985,6 @@
TocoOperator* op) const override {
op->keep_dims = options.keep_dims();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class ReduceMin
@@ -1318,16 +1002,6 @@
TocoOperator* op) const override {
op->keep_dims = options.keep_dims();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class ReduceProd
@@ -1345,10 +1019,6 @@
TocoOperator* op) const override {
op->keep_dims = options.keep_dims();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class ReduceAny
@@ -1366,24 +1036,6 @@
TocoOperator* op) const override {
op->keep_dims = options.keep_dims();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
-};
-
-class Relu6 : public SimpleOperator<Relu6Operator> {
- public:
- explicit Relu6() : SimpleOperator("RELU6", OperatorType::kRelu6) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class ResizeBilinear
@@ -1402,16 +1054,6 @@
TocoOperator* op) const override {
op->align_corners = options.align_corners();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op takes int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class ResizeNearestNeighbor
@@ -1431,16 +1073,6 @@
TocoOperator* op) const override {
op->align_corners = options.align_corners();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Squeeze
@@ -1462,10 +1094,6 @@
options.squeeze_dims()->begin(),
options.squeeze_dims()->end());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class Split
@@ -1484,18 +1112,6 @@
TocoOperator* op) const override {
op->num_split = options.num_splits();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2, for int32 it's version 3.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- } else if (input_array.data_type == ArrayDataType::kInt32) {
- return 3;
- }
- return 1;
- }
};
class SplitV
@@ -1514,10 +1130,6 @@
TocoOperator* op) const override {
op->num_split = options.num_splits();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class StridedSlice
@@ -1542,16 +1154,6 @@
op->new_axis_mask = options.new_axis_mask();
op->shrink_axis_mask = options.shrink_axis_mask();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
@@ -1566,15 +1168,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
@@ -1592,16 +1185,6 @@
TocoOperator* op) const override {
op->output_data_type = DataType::Deserialize(options.output_type());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
-
- return 1;
- }
};
class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
@@ -1619,16 +1202,6 @@
TocoOperator* op) const override {
op->output_data_type = DataType::Deserialize(options.output_type());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
-
- return 1;
- }
};
class TransposeConv
@@ -1652,10 +1225,6 @@
op->stride_width = options.stride_w();
op->stride_height = options.stride_h();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class SparseToDense
@@ -1675,22 +1244,6 @@
TocoOperator* op) const override {
op->validate_indices = options.validate_indices();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& value_input_name = op_signature.op->inputs[2];
- const Array& value_input_array =
- op_signature.model->GetArray(value_input_name);
- // Version 3 supports Int8 and Uint8 type.
- if (value_input_array.data_type == ArrayDataType::kInt8 ||
- value_input_array.data_type == ArrayDataType::kUint8) {
- return 3;
- }
- // Version 2 supports Int64 value type.
- if (value_input_array.data_type == ArrayDataType::kInt64) {
- return 2;
- }
- return 1;
- }
};
class ExpandDims
@@ -1707,10 +1260,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
@@ -1729,16 +1278,6 @@
op->values_count = options.values_count();
op->axis = options.axis();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class Shape
@@ -1757,42 +1296,6 @@
TocoOperator* op) const override {
op->output_data_type = DataType::Deserialize(options.out_type());
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
-};
-
-class Slice : public SimpleOperator<SliceOperator> {
- public:
- explicit Slice() : SimpleOperator("SLICE", OperatorType::kSlice) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- if (input_array.data_type == ArrayDataType::kInt8) {
- // Version 2 supports signed int8 input types.
- return 2;
- }
- if (input_array.data_type == ArrayDataType::kString) {
- // Version 3 supports string input types.
- return 3;
- }
- return 1;
- }
-};
-
-class Tanh : public SimpleOperator<TanhOperator> {
- public:
- explicit Tanh() : SimpleOperator("TANH", OperatorType::kTanh) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
@@ -1808,10 +1311,6 @@
TocoOperator* op) const override {
op->axis = options.axis();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class CTCBeamSearchDecoder
@@ -1851,17 +1350,6 @@
op->num = options.num();
op->axis = options.axis();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // If the op take int8/uint8 input, it is version 2.
- if (input_array.data_type == ArrayDataType::kInt8 ||
- input_array.data_type == ArrayDataType::kUint8) {
- return 2;
- }
- return 1;
- }
};
class LeakyRelu
@@ -1878,39 +1366,6 @@
TocoOperator* op) const override {
op->alpha = options.alpha();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
-};
-
-class Logistic : public SimpleOperator<LogisticOperator> {
- public:
- explicit Logistic() : SimpleOperator("LOGISTIC", OperatorType::kLogistic) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class LogSoftmax : public SimpleOperator<LogSoftmaxOperator> {
- public:
- explicit LogSoftmax()
- : SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class SquaredDifference
@@ -1928,10 +1383,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class MirrorPad
@@ -1953,8 +1404,6 @@
? MirrorPadMode::kReflect
: MirrorPadMode::kSymmetric;
}
-
- int GetVersion(const OperatorSignature& op) const override { return 1; }
};
class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
@@ -1978,10 +1427,6 @@
? toco::ArrayDataType::kInt64
: toco::ArrayDataType::kInt32;
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
class UnidirectionalSequenceRnn
@@ -2005,10 +1450,6 @@
::tflite::ActivationFunctionType_TANH);
}
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
-
std::vector<bool> GetMutatingInputVariables(
const Operator& op) const override {
std::vector<bool> mutating_input_variables(op.inputs.size(), false);
@@ -2030,10 +1471,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
@@ -2264,21 +1701,6 @@
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 3 supports signed int16 input types.
- if (input_array.data_type == ArrayDataType::kInt16 ||
- input_array.data_type == ArrayDataType::kFloat16) {
- return 3;
- }
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
};
class ReverseSequence
@@ -2300,127 +1722,8 @@
op->seq_dim = options.seq_dim();
op->batch_dim = options.batch_dim();
}
-
- int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
- }
};
-class Equal : public SimpleOperator<TensorFlowEqualOperator> {
- public:
- explicit Equal() : SimpleOperator("EQUAL", OperatorType::kEqual) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class NotEqual : public SimpleOperator<TensorFlowNotEqualOperator> {
- public:
- explicit NotEqual() : SimpleOperator("NOT_EQUAL", OperatorType::kNotEqual) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class Greater : public SimpleOperator<TensorFlowGreaterOperator> {
- public:
- explicit Greater() : SimpleOperator("GREATER", OperatorType::kGreater) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class GreaterEqual : public SimpleOperator<TensorFlowGreaterEqualOperator> {
- public:
- explicit GreaterEqual()
- : SimpleOperator("GREATER_EQUAL", OperatorType::kGreaterEqual) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class Less : public SimpleOperator<TensorFlowLessOperator> {
- public:
- explicit Less() : SimpleOperator("LESS", OperatorType::kLess) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class LessEqual : public SimpleOperator<TensorFlowLessEqualOperator> {
- public:
- explicit LessEqual()
- : SimpleOperator("LESS_EQUAL", OperatorType::kLessEqual) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class Select : public SimpleOperator<SelectOperator> {
- public:
- explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports signed int8 input types.
- if (input_array.data_type == ArrayDataType::kInt8) {
- return 2;
- }
- return 1;
- }
-};
-
-class FloorDiv : public SimpleOperator<FloorDivOperator> {
- public:
- explicit FloorDiv() : SimpleOperator("FLOOR_DIV", OperatorType::kFloorDiv) {}
- int GetVersion(const OperatorSignature& op_signature) const override {
- const string& input_name = op_signature.op->inputs[0];
- const Array& input_array = op_signature.model->GetArray(input_name);
- // Version 2 supports float input types.
- if (input_array.data_type == ArrayDataType::kFloat) {
- return 2;
- }
- return 1;
- }
-};
-
-
namespace {
// Build a vector containing all the known operators.
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
@@ -2570,9 +1873,9 @@
MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
OperatorType::kReverseSequence));
ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
- "MATRIX_DIAG", OperatorType::kMatrixDiag));
+ ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
- "MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag));
+ ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
// Custom Operators.
ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
"CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
@@ -2585,81 +1888,96 @@
// when custom ops are exported but SimpleOperator bypasses thoses. To
// prevent user confusion we are settling on using SimpleOperator only for
// builtins.
- ops.push_back(
- MakeUnique<SimpleOperator<FloorOperator>>("FLOOR", OperatorType::kFloor));
- ops.push_back(
- MakeUnique<SimpleOperator<CeilOperator>>("CEIL", OperatorType::kCeil));
- ops.push_back(
- MakeUnique<SimpleOperator<EluOperator>>("ELU", OperatorType::kElu));
- ops.push_back(
- MakeUnique<SimpleOperator<RoundOperator>>("ROUND", OperatorType::kRound));
- ops.push_back(
- MakeUnique<SimpleOperator<ReluOperator>>("RELU", OperatorType::kRelu));
+ ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
+ ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
+ ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
+ ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
+ ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
+ ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
+ ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
+ ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
+ ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
+ ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
- "RELU_N1_TO_1", OperatorType::kRelu1));
- ops.push_back(MakeUnique<Relu6>());
- ops.push_back(
- MakeUnique<SimpleOperator<PReluOperator>>("PRELU", OperatorType::kPRelu));
- ops.push_back(MakeUnique<Logistic>());
- ops.push_back(MakeUnique<Tanh>());
- ops.push_back(
- MakeUnique<SimpleOperator<ExpOperator>>("EXP", OperatorType::kExp));
- ops.push_back(
- MakeUnique<SimpleOperator<CosOperator>>("COS", OperatorType::kCos));
- ops.push_back(MakeUnique<LogSoftmax>());
- ops.push_back(MakeUnique<Maximum>()); // Element-wise Maximum
- ops.push_back(MakeUnique<Minimum>()); // Element-wise Minimum
- ops.push_back(MakeUnique<Greater>());
- ops.push_back(MakeUnique<GreaterEqual>());
- ops.push_back(MakeUnique<Less>());
- ops.push_back(MakeUnique<LessEqual>());
- ops.push_back(MakeUnique<Equal>());
- ops.push_back(MakeUnique<NotEqual>());
- ops.push_back(
- MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg));
- ops.push_back(MakeUnique<Select>());
- ops.push_back(MakeUnique<Slice>());
- ops.push_back(
- MakeUnique<SimpleOperator<PowOperator>>("POW", OperatorType::kPow));
+ ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
+ ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
+ ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
+ ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
+ ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
+ ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
+ ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
+ ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
+ ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
+ ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
+ ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
+ ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
+ ::tflite::BuiltinOperator_COS, OperatorType::kCos));
+ ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
+ ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
+ ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
+ ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
+ ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
+ ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
+ ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
+ ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
+ ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
+ ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
+ ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
+ ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
+ ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
+ ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
+ ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
+ ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
+ ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
+ ::tflite::BuiltinOperator_POW, OperatorType::kPow));
ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
- "LOGICAL_OR", OperatorType::kLogicalOr));
+ ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
- "LOGICAL_AND", OperatorType::kLogicalAnd));
+ ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
- "LOGICAL_NOT", OperatorType::kLogicalNot));
- ops.push_back(MakeUnique<FloorDiv>());
+ ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
+ ops.emplace_back(new SimpleOperator<FloorDivOperator>(
+ ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
ops.emplace_back(new SimpleOperator<FloorModOperator>(
- "FLOOR_MOD", OperatorType::kFloorMod));
- ops.emplace_back(
- new SimpleOperator<RangeOperator>("RANGE", OperatorType::kRange));
+ ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
+ ops.emplace_back(new SimpleOperator<RangeOperator>(
+ ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
// Element-wise operator
- ops.push_back(
- MakeUnique<SimpleOperator<SinOperator>>("SIN", OperatorType::kSin));
- ops.push_back(
- MakeUnique<SimpleOperator<LogOperator>>("LOG", OperatorType::kLog));
+ ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
+ ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
+ ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
+ ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
- "SQRT", OperatorType::kSqrt));
+ ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
- "RSQRT", OperatorType::kRsqrt));
+ ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
- "SQUARE", OperatorType::kSquare));
+ ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
- "ZEROS_LIKE", OperatorType::kZerosLike));
- ops.push_back(
- MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs));
+ ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
+ ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
+ ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
- "HARD_SWISH", OperatorType::kHardSwish));
- ops.push_back(
- MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
+ ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
+ ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
+ ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
- "REVERSE_V2", OperatorType::kReverseV2));
+ ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
- "RANK", OperatorType::kRank));
+ ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
return ops;
}
} // namespace
-// LINT.ThenChange(//tensorflow/lite/toco/tflite/op_version.cc)
+// LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
bool enable_select_tf_ops) {
diff --git a/tensorflow/lite/toco/tflite/operator.h b/tensorflow/lite/toco/tflite/operator.h
index 899db1a..19d9214 100644
--- a/tensorflow/lite/toco/tflite/operator.h
+++ b/tensorflow/lite/toco/tflite/operator.h
@@ -19,6 +19,7 @@
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/tools/versioning/op_version.h"
namespace toco {
@@ -93,8 +94,9 @@
// * The first version for each op should be 1 (to be consistent with the
// default value in Flatbuffer. `return 1;` is okay for newly implemented
// ops.
- // * When multiple versions are defined for an op, this function needs to be
- // overridden. (See example in `operator_test.cc`)
+ // * When multiple versions are defined for an op, this function could be
+ // overridden. (See example in `operator_test.cc` and
+ // `tools/versioning/op_version.cc`)
virtual int GetVersion(const OperatorSignature& op_signature) const = 0;
// Given a Toco `Operator`, return a list of booleans indicating the op
@@ -113,6 +115,11 @@
OperatorType type_;
};
+// Helper function to create ::tflite::OpSignature from the given
+// ::tflite::BuiltinOperator and OperatorSignature.
+::tflite::OpSignature GetVersioningOpSig(const ::tflite::BuiltinOperator op,
+ const OperatorSignature& op_signature);
+
// Helper function to determine if a unsupported TensorFlow op should be
// exported as an Flex op or a regular custom op.
bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
diff --git a/tensorflow/lite/toco/tflite/simple_operator.h b/tensorflow/lite/toco/tflite/simple_operator.h
index 2900748..150b0d0 100644
--- a/tensorflow/lite/toco/tflite/simple_operator.h
+++ b/tensorflow/lite/toco/tflite/simple_operator.h
@@ -32,6 +32,11 @@
class SimpleOperator : public BaseOperator {
public:
using BaseOperator::BaseOperator;
+
+ SimpleOperator(::tflite::BuiltinOperator op, OperatorType type)
+ : BaseOperator(::tflite::EnumNameBuiltinOperator(op), type),
+ builtin_op_(op) {}
+
Options Serialize(const Operator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
return Options();
@@ -43,8 +48,14 @@
}
int GetVersion(const OperatorSignature& op_signature) const override {
- return 1;
+ return ::tflite::GetBuiltinOperatorVersion(
+ GetVersioningOpSig(builtin_op_, op_signature));
}
+
+ ::tflite::BuiltinOperator builtin_op() const { return builtin_op_; }
+
+ private:
+ const ::tflite::BuiltinOperator builtin_op_;
};
} // namespace tflite
diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD
new file mode 100644
index 0000000..f59cfba
--- /dev/null
+++ b/tensorflow/lite/tools/versioning/BUILD
@@ -0,0 +1,33 @@
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "op_version",
+ srcs = ["op_version.cc"],
+ hdrs = [
+ "op_version.h",
+ ],
+ deps = [
+ "//tensorflow/lite/schema:schema_fbs",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "op_version_test",
+ srcs = ["op_version_test.cc"],
+ deps = [
+ ":op_version",
+ "//tensorflow/lite/schema:schema_fbs",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
new file mode 100644
index 0000000..b7920d2
--- /dev/null
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -0,0 +1,283 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/tools/versioning/op_version.h"
+
+#include <cstring>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+
+namespace tflite {
+
+int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
+ switch (op_sig.op) {
+ case BuiltinOperator_CONV_2D:
+ // If the op has signed int8 op_sig.inputs and op_sig.outputs, it's
+ // version 3.
+ if (op_sig.input_types.at(0) == TensorType_INT8 &&
+ op_sig.input_types.at(1) == TensorType_INT8 &&
+ op_sig.output_types.at(0) == TensorType_INT8) {
+ return 3;
+ }
+ // If the op is a signed int8 hybrid operation, we need to return
+ // version 2.
+ if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+ op_sig.input_types.at(1) == TensorType_INT8 &&
+ op_sig.output_types.at(0) == TensorType_FLOAT32) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_DEPTHWISE_CONV_2D:
+ // If the op has signed int8 op_sig.inputs and op_sig.outputs, it's
+ // version 3.
+ if (op_sig.input_types.at(0) == TensorType_INT8 &&
+ op_sig.input_types.at(1) == TensorType_INT8 &&
+ op_sig.output_types.at(0) == TensorType_INT8) {
+ return 3;
+ }
+ if (op_sig.options.depthwise_conv_2d.dilation_w_factor != 1 ||
+ op_sig.options.depthwise_conv_2d.dilation_h_factor != 1) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_FAKE_QUANT:
+ if (op_sig.options.fakequant.narrow_range) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_FULLY_CONNECTED:
+ // +-----------------+--------------------+--------------------------+
+ // | | Weight::Default | Weight::Shuffled4x16Int8 |
+ // +-----------------+--------------------+--------------------------+
+ // | Float | 1 | 2 |
+ // | Quantized Uint8 | 1 | 2 |
+ // | Hybrid | 3 | 3 |
+ // | Quantized Int8 | 4 | 4 |
+ // +-----------------+--------------------+--------------------------+
+ // 2 op_sig.inputs (no bias) use case is supported starting from
+ // version 6.
+ if (op_sig.input_types.size() == 2) {
+ return 6;
+ }
+ // `keep_num_dims` is supported at version 5.
+ if (op_sig.options.fully_connected.keep_num_dims) {
+ return 5;
+ }
+ // Int8 fully fixed point kernel is at version 4.
+ if (op_sig.input_types.at(0) == TensorType_INT8 &&
+ op_sig.input_types.at(1) == TensorType_INT8 &&
+ op_sig.output_types.at(0) == TensorType_INT8) {
+ return 4;
+ }
+ // If the op is a signed int8 hybrid operation, we need to return
+ // version 3.
+ if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+ op_sig.input_types.at(1) == TensorType_INT8 &&
+ op_sig.output_types.at(0) == TensorType_FLOAT32) {
+ return 3;
+ }
+ // For float and uint8 fixed point kernels, if the weight is
+ // Shuffled4x16Int8, it is version 2.
+ if (op_sig.options.fully_connected.weights_format ==
+ FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8) {
+ return 2;
+ }
+ // Otherwise (weight is default), the version is 1.
+ return 1;
+
+ case BuiltinOperator_GATHER:
+ // If the op takes bool input, it is version 3.
+ if (op_sig.input_types.at(0) == TensorType_BOOL) {
+ return 3;
+ }
+ if (op_sig.input_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_SVDF:
+ // If the op is a signed int8 hybrid operation, we need to return
+ // version 2.
+ if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+ op_sig.input_types.at(1) == TensorType_INT8 &&
+ op_sig.output_types.at(0) == TensorType_FLOAT32) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_MUL:
+ // Version 3 supports a rescale value greater than or equal to 1.
+ if (op_sig.options.mul.input1_scale != 0 &&
+ op_sig.options.mul.input2_scale != 0 &&
+ op_sig.options.mul.output_scale != 0 &&
+ (op_sig.options.mul.input1_scale * op_sig.options.mul.input2_scale /
+ op_sig.options.mul.output_scale) >= 1.0) {
+ return 3;
+ }
+ if (op_sig.input_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_TRANSPOSE:
+ // If the op takes bool input, it is version 3.
+ if (op_sig.input_types.at(0) == TensorType_BOOL) {
+ return 3;
+ }
+ if (op_sig.input_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_LSTM:
+ // If the input tensor is float and a weight is int8, this is a version
+ // 3 hybrid operation.
+ if (op_sig.options.lstm.kernel_type == LSTMKernelType_FULL &&
+ op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+ op_sig.input_types.at(2) == TensorType_INT8 &&
+ op_sig.output_types.at(0) == TensorType_FLOAT32) {
+ return 3;
+ }
+ // KERNEL_BASIC was added in version 2.
+ if (op_sig.options.lstm.kernel_type == LSTMKernelType_BASIC) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+ // If the input tensor is float and a weight is int8, this is a version
+ // 2 hybrid operation.
+ if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
+ op_sig.input_types.at(2) == TensorType_INT8 &&
+ op_sig.output_types.at(0) == TensorType_FLOAT32) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_SPLIT:
+ // If the op takes int8 input, it is version 2; for int32 it's version 3.
+ if (op_sig.input_types.at(0) == TensorType_INT32) {
+ return 3;
+ }
+ if (op_sig.input_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_SPARSE_TO_DENSE:
+ // Version 3 supports Int8 and Uint8 type.
+ if (op_sig.input_types.at(2) == TensorType_INT8 ||
+ op_sig.input_types.at(2) == TensorType_UINT8) {
+ return 3;
+ }
+ // Version 2 supports Int64 value type.
+ if (op_sig.input_types.at(2) == TensorType_INT64) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_SLICE:
+ // Version 3 supports string input types.
+ if (op_sig.input_types.at(0) == TensorType_STRING) {
+ return 3;
+ }
+ if (op_sig.input_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_UNPACK:
+ // If the op takes int8/uint8 input, it is version 2.
+ if (op_sig.input_types.at(0) == TensorType_INT8 ||
+ op_sig.input_types.at(0) == TensorType_UINT8) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_DEQUANTIZE:
+ // Version 3 supports signed int16 and float16 input types.
+ if (op_sig.input_types.at(0) == TensorType_INT16 ||
+ op_sig.input_types.at(0) == TensorType_FLOAT16) {
+ return 3;
+ }
+ if (op_sig.input_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_FLOOR_DIV:
+ if (op_sig.input_types.at(0) == TensorType_FLOAT32) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_L2_NORMALIZATION:
+ if (op_sig.output_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
+
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ case BuiltinOperator_ADD:
+ case BuiltinOperator_SPACE_TO_BATCH_ND:
+ case BuiltinOperator_SUB:
+ case BuiltinOperator_BATCH_TO_SPACE_ND:
+ case BuiltinOperator_CONCATENATION:
+ case BuiltinOperator_MAX_POOL_2D:
+ case BuiltinOperator_MAXIMUM:
+ case BuiltinOperator_MINIMUM:
+ case BuiltinOperator_PAD:
+ case BuiltinOperator_PADV2:
+ case BuiltinOperator_SOFTMAX:
+ case BuiltinOperator_SPACE_TO_DEPTH:
+ case BuiltinOperator_MEAN:
+ case BuiltinOperator_SUM:
+ case BuiltinOperator_REDUCE_MAX:
+ case BuiltinOperator_REDUCE_MIN:
+ case BuiltinOperator_RELU6:
+ case BuiltinOperator_RESIZE_BILINEAR:
+ case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
+ case BuiltinOperator_PACK:
+ case BuiltinOperator_TANH:
+ case BuiltinOperator_LOGISTIC:
+ case BuiltinOperator_LOG_SOFTMAX:
+ case BuiltinOperator_STRIDED_SLICE:
+ case BuiltinOperator_TOPK_V2:
+ case BuiltinOperator_ARG_MAX:
+ case BuiltinOperator_ARG_MIN:
+ case BuiltinOperator_EQUAL:
+ case BuiltinOperator_NOT_EQUAL:
+ case BuiltinOperator_GREATER:
+ case BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator_LESS:
+ case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_SELECT:
+ if (op_sig.input_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
+
+ default:
+ return 1;
+ }
+}
+
+} // namespace tflite
diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h
new file mode 100644
index 0000000..b653896
--- /dev/null
+++ b/tensorflow/lite/tools/versioning/op_version.h
@@ -0,0 +1,65 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
+#define TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
+
+#include <vector>
+
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// OpSignature contains the parameters of one operator instance that are
+// needed to decide the minimum op version able to run it. A plain
+// `struct` is used rather than a C-style `typedef struct` since this
+// header is C++-only.
+struct OpSignature {
+  // The builtin operator whose version is being queried.
+  BuiltinOperator op;
+  // Tensor types of the op's inputs, in input order.
+  std::vector<TensorType> input_types;
+  // Tensor types of the op's outputs, in output order.
+  std::vector<TensorType> output_types;
+  // Op-specific parameters; only the member matching `op` is meaningful.
+  union {
+    struct {
+      int32_t dilation_w_factor;
+      int32_t dilation_h_factor;
+    } depthwise_conv_2d;
+    struct {
+      bool narrow_range;
+    } fakequant;
+    struct {
+      bool keep_num_dims;
+      FullyConnectedOptionsWeightsFormat weights_format;
+    } fully_connected;
+    struct {
+      float input1_scale;
+      float input2_scale;
+      float output_scale;
+    } mul;
+    struct {
+      LSTMKernelType kernel_type;
+    } lstm;
+  } options;
+};
+
+// Returns the version of the given builtin op required by the parameters
+// described in `op_sig`.
+int GetBuiltinOperatorVersion(const OpSignature& op_sig);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
new file mode 100644
index 0000000..eeadd68
--- /dev/null
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -0,0 +1,334 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/tools/versioning/op_version.h"
+
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+TEST(OpVersionTest, VersioningSparseToDense) {  // Fixed typo: "Spare" -> "Sparse".
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_SPARSE_TO_DENSE,
+      .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8,
+                                             TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);  // int8 indices need v3.
+
+  fake_op_sig = {
+      .op = BuiltinOperator_SPARSE_TO_DENSE,
+      .input_types = std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8,
+                                             TensorType_UINT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);  // uint8 indices need v3.
+
+  fake_op_sig = {
+      .op = BuiltinOperator_SPARSE_TO_DENSE,
+      .input_types = std::vector<TensorType>{TensorType_INT64, TensorType_INT64,
+                                             TensorType_INT64},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);  // int64 indices need v2.
+
+  fake_op_sig = {
+      .op = BuiltinOperator_SPARSE_TO_DENSE,
+      .input_types = std::vector<TensorType>{TensorType_INT32, TensorType_INT32,
+                                             TensorType_INT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);  // int32 is the base v1.
+}
+
+// Test version for a simple Op with 2 versions and the input type controls the
+// version.
+void SimpleVersioningTest(BuiltinOperator op) {
+  OpSignature fake_op_sig = {
+      .op = op,
+      .input_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);  // int8 input requires v2.
+
+  fake_op_sig = {
+      .op = op,
+      .input_types = std::vector<TensorType>{TensorType_UINT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);  // uint8 is the original v1.
+}
+
+// Test version for a simple Op with 2 versions and the output type controls the version.
+void SimpleOutputVersioningTest(BuiltinOperator op) {
+  OpSignature fake_op_sig = {
+      .op = op,
+      .input_types = std::vector<TensorType>{},
+      .output_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);  // int8 output requires v2.
+
+  fake_op_sig = {
+      .op = op,
+      .input_types = std::vector<TensorType>{},
+      .output_types = std::vector<TensorType>{TensorType_UINT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);  // uint8 is the original v1.
+}
+
+TEST(OpVersionTest, VersioningEqualTest) {
+ SimpleVersioningTest(BuiltinOperator_EQUAL);
+}
+
+TEST(OpVersionTest, VersioningNotEqualTest) {
+ SimpleVersioningTest(BuiltinOperator_NOT_EQUAL);
+}
+
+TEST(OpVersionTest, VersioningLessTest) {
+ SimpleVersioningTest(BuiltinOperator_LESS);
+}
+
+TEST(OpVersionTest, VersioningLessEqualTest) {
+ SimpleVersioningTest(BuiltinOperator_LESS_EQUAL);
+}
+
+TEST(OpVersionTest, VersioningGreaterTest) {
+ SimpleVersioningTest(BuiltinOperator_GREATER);
+}
+
+TEST(OpVersionTest, VersioningGreaterEqualTest) {
+ SimpleVersioningTest(BuiltinOperator_GREATER_EQUAL);
+}
+
+TEST(OpVersionTest, VersioningSpaceToBatchNDTest) {
+  // Fixed copy-paste bug: this test previously exercised NOT_EQUAL.
+  SimpleVersioningTest(BuiltinOperator_SPACE_TO_BATCH_ND);
+}
+
+TEST(OpVersionTest, VersioningLogSoftmaxTest) {
+ SimpleVersioningTest(BuiltinOperator_LOG_SOFTMAX);
+}
+
+TEST(OpVersionTest, VersioningPackTest) {
+ SimpleVersioningTest(BuiltinOperator_PACK);
+}
+
+TEST(OpVersionTest, VersioningUnpackTest) {
+ OpSignature fake_op_sig = {
+ .op = BuiltinOperator_UNPACK,
+ .input_types = std::vector<TensorType>{TensorType_INT8},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+ fake_op_sig = {
+ .op = BuiltinOperator_UNPACK,
+ .input_types = std::vector<TensorType>{TensorType_UINT8},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+ fake_op_sig = {
+ .op = BuiltinOperator_UNPACK,
+ .input_types = std::vector<TensorType>{TensorType_INT32},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+}
+
+TEST(OpVersionTest, VersioningBatchToSpaceNDTest) {
+ SimpleVersioningTest(BuiltinOperator_BATCH_TO_SPACE_ND);
+}
+
+TEST(OpVersionTest, VersioningTanhTest) {
+ SimpleVersioningTest(BuiltinOperator_TANH);
+}
+
+TEST(OpVersionTest, VersioningStridedSliceTest) {
+ SimpleVersioningTest(BuiltinOperator_STRIDED_SLICE);
+}
+
+TEST(OpVersionTest, VersioningSpaceToDepthTest) {
+ SimpleVersioningTest(BuiltinOperator_SPACE_TO_DEPTH);
+}
+
+TEST(OpVersionTest, VersioningSliceTest) {
+ OpSignature fake_op_sig = {
+ .op = BuiltinOperator_SLICE,
+ .input_types = std::vector<TensorType>{TensorType_STRING},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+ fake_op_sig = {
+ .op = BuiltinOperator_SLICE,
+ .input_types = std::vector<TensorType>{TensorType_INT8},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+ fake_op_sig = {
+ .op = BuiltinOperator_SLICE,
+ .input_types = std::vector<TensorType>{TensorType_UINT8},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+}
+
+TEST(OpVersionTest, VersioningLogisticTest) {
+  // Fixed copy-paste bug: this test previously exercised SPACE_TO_DEPTH.
+  SimpleVersioningTest(BuiltinOperator_LOGISTIC);
+}
+
+TEST(OpVersionTest, VersioningL2NormTest) {
+ SimpleOutputVersioningTest(BuiltinOperator_L2_NORMALIZATION);
+}
+
+TEST(OpVersionTest, VersioningMaxTest) {
+ SimpleVersioningTest(BuiltinOperator_MAXIMUM);
+}
+
+TEST(OpVersionTest, VersioningMinTest) {
+ SimpleVersioningTest(BuiltinOperator_MINIMUM);
+}
+
+TEST(OpVersionTest, VersioningMeanTest) {
+ SimpleVersioningTest(BuiltinOperator_MEAN);
+}
+
+TEST(OpVersionTest, VersioningSumTest) {
+ SimpleVersioningTest(BuiltinOperator_SUM);
+}
+
+TEST(OpVersionTest, VersioningAddTest) {
+ SimpleVersioningTest(BuiltinOperator_ADD);
+}
+
+TEST(OpVersionTest, VersioningSubTest) {
+ SimpleVersioningTest(BuiltinOperator_SUB);
+}
+
+void SimpleMulVersioningTest(TensorType data_type, float multiplier,
+                             int version) {
+  // Builds a MUL signature whose two inputs and output share `data_type`,
+  // with quantization scales derived from `multiplier`, and checks that
+  // the computed op version equals `version`.
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_MUL,
+      .input_types = std::vector<TensorType>{data_type, data_type},
+      .output_types = std::vector<TensorType>{data_type},
+  };
+  fake_op_sig.options.mul = {1.0f, 1.0f, 1.0f / multiplier};  // {input1, input2, output} scales.
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), version);
+}
+
+TEST(OpVersionTest, VersioningMulTest) {
+ SimpleMulVersioningTest(TensorType_UINT8, 0.5f, 1);
+ SimpleMulVersioningTest(TensorType_INT8, 0.5f, 2);
+ SimpleMulVersioningTest(TensorType_INT8, 2.0f, 3);
+}
+
+TEST(OpVersionTest, VersioningPadTest) {
+ SimpleVersioningTest(BuiltinOperator_PAD);
+}
+
+TEST(OpVersionTest, VersioningPadV2Test) {
+ SimpleVersioningTest(BuiltinOperator_PADV2);
+}
+
+TEST(OpVersionTest, VersioningConcatenationTest) {
+ SimpleVersioningTest(BuiltinOperator_CONCATENATION);
+}
+
+TEST(OpVersionTest, VersioningSelectTest) {
+ SimpleVersioningTest(BuiltinOperator_SELECT);
+}
+
+TEST(OpVersionTest, VersioningRelu6Test) {
+ SimpleVersioningTest(BuiltinOperator_RELU6);
+}
+
+TEST(OpVersionTest, VersioningFullyConnectedTest) {
+ OpSignature fake_op_sig = {
+ .op = BuiltinOperator_FULLY_CONNECTED,
+ .input_types =
+ std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8},
+ .output_types = std::vector<TensorType>{TensorType_UINT8},
+ };
+ fake_op_sig.options.fully_connected = {
+ false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
+
+ fake_op_sig = {
+ .op = BuiltinOperator_FULLY_CONNECTED,
+ .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
+ .output_types = std::vector<TensorType>{TensorType_INT8},
+ };
+ fake_op_sig.options.fully_connected = {
+ false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
+}
+
+TEST(OpVersionTest, VersioningDequantizeTest) {
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_DEQUANTIZE,
+      .input_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);  // v3 adds int16 input.
+
+  fake_op_sig = {
+      .op = BuiltinOperator_DEQUANTIZE,
+      .input_types = std::vector<TensorType>{TensorType_FLOAT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);  // v3 also adds float16.
+
+  fake_op_sig = {
+      .op = BuiltinOperator_DEQUANTIZE,
+      .input_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);  // v2 adds int8 input.
+
+  fake_op_sig = {
+      .op = BuiltinOperator_DEQUANTIZE,
+      .input_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);  // float32 is the base v1.
+}
+
+TEST(OpVersionTest, VersioningConv2DTest) {
+ OpSignature fake_op_sig = {
+ .op = BuiltinOperator_CONV_2D,
+ .input_types =
+ std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8},
+ .output_types = std::vector<TensorType>{TensorType_UINT8},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+ fake_op_sig = {
+ .op = BuiltinOperator_CONV_2D,
+ .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
+ .output_types = std::vector<TensorType>{TensorType_INT8},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+ fake_op_sig = {
+ .op = BuiltinOperator_CONV_2D,
+ .input_types =
+ std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
+ .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+}
+
+TEST(OpVersionTest, VersioningFloorDivOperatorTest) {
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_FLOOR_DIV,
+      .input_types = std::vector<TensorType>{TensorType_INT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);  // int32 is the base v1.
+
+  fake_op_sig = {
+      .op = BuiltinOperator_FLOOR_DIV,
+      .input_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);  // v2 adds float32 input.
+}
+
+} // namespace tflite