Add ophint converted unidirectional sequence lstm opdef to graphdef_to_flatbuffer, so usl will be legalized as customized op and also add the e2e test.
PiperOrigin-RevId: 297024636
Change-Id: Ic2cec9ab899ad1d1ea2ade830b8bee07bc4b9551
diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index f2b89ae..98641ad 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -63,6 +63,30 @@
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
+const char kUnidirectionalSequenceLstmOp[] =
+ "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
+ "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
+ "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
+ "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
+ "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
+ "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
+ "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
+ "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
+ "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
+ "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
+ "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
+ "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
+ "type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
+ "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
+ "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
+ "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
+ "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
+ "name: 'InputCellStateTensor' type: DT_FLOAT } "
+ "output_arg: { name: 'Concat' type: DT_FLOAT} "
+ "output_arg: { name: "
+ "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
+ "attr : { name: '_tflite_input_indices' type: 'list(int)'}";
+
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
@@ -260,6 +284,7 @@
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
+ extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
index 48e434a..303741b 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
@@ -258,6 +258,10 @@
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
+ # Test MLIR-Converted model.
+ result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, True)
+ self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
+
@test_util.enable_control_flow_v2
def testDynamicRnnMultiRnnCell(self):
sess = tf.compat.v1.Session(config=CONFIG)
@@ -278,6 +282,10 @@
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
+ # Test MLIR-converted model.
+ result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, True)
+ self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
+
if __name__ == "__main__":
test.main()