Add bool support in TFL unpack op

This bool type support is required for tf.layers.keras.RNN with unroll=True.

PiperOrigin-RevId: 286955661
Change-Id: Ie557febc346e44978c4e3e96d828b21098ab3afb
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 1932122..b8b0ef6 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -2359,14 +2359,14 @@
   }];
 
   let arguments = (ins
-    TensorOf<[F32, I8, I32, QI8, QUI8]>:$input,
+    TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input,
 
     I32Attr:$num,
     I32Attr:$axis
   );
 
   let results = (outs
-    Variadic<TensorOf<[F32, I8, I32, QI8, QUI8]>>:$outputs
+    Variadic<TensorOf<[F32, I1, I8, I32, QI8, QUI8]>>:$outputs
   );
 
   let verifier = [{ return Verify(*this); }];
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 1a545b1..620f6ee 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -247,7 +247,7 @@
   AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
   AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(),
              /* min_version */ 1,
-             /* max_version */ 2);
+             /* max_version */ 3);
   AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV(),
              /* min_version */ 1,
              /* max_version */ 2);
diff --git a/tensorflow/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc
index 7de891c..8e66432 100644
--- a/tensorflow/lite/kernels/unpack.cc
+++ b/tensorflow/lite/kernels/unpack.cc
@@ -43,7 +43,8 @@
   }
   TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
   if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
-      input->type != kTfLiteUInt8 && input->type != kTfLiteInt8) {
+      input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
+      input->type != kTfLiteBool) {
     context->ReportError(context, "Type '%s' is not supported by unpack.",
                          TfLiteTypeGetName(input->type));
     return kTfLiteError;
@@ -112,6 +113,10 @@
       UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
       break;
     }
+    case kTfLiteBool: {
+      UnpackImpl<bool>(context, node, input, data->num, data->axis);
+      break;
+    }
     default: {
       context->ReportError(context, "Type '%s' is not supported by unpack.",
                            TfLiteTypeGetName(input->type));
diff --git a/tensorflow/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc
index 28d21cc..88eb706 100644
--- a/tensorflow/lite/kernels/unpack_test.cc
+++ b/tensorflow/lite/kernels/unpack_test.cc
@@ -87,43 +87,43 @@
 TEST(UnpackOpTest, FloatThreeOutputs) {
   Check<float>(/*axis=*/0, /*input_shape=*/{3, 2},
                /*input_data=*/{1, 2, 3, 4, 5, 6},
-               /*expected_output_shape=*/{{2}, {2}, {2}},
-               /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
+               /*exp_output_shape=*/{{2}, {2}, {2}},
+               /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
 }
 
 TEST(UnpackOpTest, FloatThreeOutputsAxisOne) {
   Check<float>(/*axis=*/1, /*input_shape=*/{3, 2},
                /*input_data=*/{1, 2, 3, 4, 5, 6},
-               /*expected_output_shape=*/{{3}, {3}},
-               /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}});
+               /*exp_output_shape=*/{{3}, {3}},
+               /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}});
 }
 
 TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisOne) {
   Check<float>(/*axis=*/-1, /*input_shape=*/{3, 2},
                /*input_data=*/{1, 2, 3, 4, 5, 6},
-               /*expected_output_shape=*/{{3}, {3}},
-               /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}});
+               /*exp_output_shape=*/{{3}, {3}},
+               /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}});
 }
 
 TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisTwo) {
   Check<float>(/*axis=*/-2, /*input_shape=*/{3, 2},
                /*input_data=*/{1, 2, 3, 4, 5, 6},
-               /*expected_output_shape=*/{{2}, {2}, {2}},
-               /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
+               /*exp_output_shape=*/{{2}, {2}, {2}},
+               /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
 }
 
 TEST(UnpackOpTest, FloatOneOutput) {
   Check<float>(/*axis=*/0, /*input_shape=*/{1, 6},
                /*input_data=*/{1, 2, 3, 4, 5, 6},
-               /*expected_output_shape=*/{{6}},
-               /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}});
+               /*exp_output_shape=*/{{6}},
+               /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}});
 }
 
 TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
   Check<float>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
                /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
-               /*expected_output_shape=*/{{2, 2}, {2, 2}},
-               /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
+               /*exp_output_shape=*/{{2, 2}, {2, 2}},
+               /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
 }
 
 TEST(UnpackOpTest, FloatVectorToScalar) {
@@ -137,32 +137,32 @@
 TEST(UnpackOpTest, IntThreeOutputs) {
   Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
                  /*input_data=*/{1, 2, 3, 4, 5, 6},
-                 /*expected_output_shape=*/{{2}, {2}, {2}},
-                 /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+                 /*exp_output_shape=*/{{2}, {2}, {2}},
+                 /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
                  /*type=*/TensorType_INT32);
 }
 
 TEST(UnpackOpTest, IntThreeOutputsAxisOne) {
   Check<int32_t>(/*axis=*/1, /*input_shape=*/{3, 2},
                  /*input_data=*/{1, 2, 3, 4, 5, 6},
-                 /*expected_output_shape=*/{{3}, {3}},
-                 /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+                 /*exp_output_shape=*/{{3}, {3}},
+                 /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
                  /*type=*/TensorType_INT32);
 }
 
 TEST(UnpackOpTest, IntOneOutput) {
   Check<int32_t>(/*axis=*/0, /*input_shape=*/{1, 6},
                  /*input_data=*/{1, 2, 3, 4, 5, 6},
-                 /*expected_output_shape=*/{{6}},
-                 /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
+                 /*exp_output_shape=*/{{6}},
+                 /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
                  /*type=*/TensorType_INT32);
 }
 
 TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
   Check<int32_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
                  /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
-                 /*expected_output_shape=*/{{2, 2}, {2, 2}},
-                 /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
+                 /*exp_output_shape=*/{{2, 2}, {2, 2}},
+                 /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
                  /*type=*/TensorType_INT32);
 }
 
@@ -178,48 +178,48 @@
 TEST(UnpackOpTest, Uint8ThreeOutputs) {
   Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
                  /*input_data=*/{1, 2, 3, 4, 5, 6},
-                 /*expected_output_shape=*/{{2}, {2}, {2}},
-                 /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+                 /*exp_output_shape=*/{{2}, {2}, {2}},
+                 /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
                  /*type=*/TensorType_UINT8);
 }
 
 TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) {
   Check<uint8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
                  /*input_data=*/{1, 2, 3, 4, 5, 6},
-                 /*expected_output_shape=*/{{3}, {3}},
-                 /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+                 /*exp_output_shape=*/{{3}, {3}},
+                 /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
                  /*type=*/TensorType_UINT8);
 }
 
 TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) {
   Check<uint8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
                  /*input_data=*/{1, 2, 3, 4, 5, 6},
-                 /*expected_output_shape=*/{{3}, {3}},
-                 /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+                 /*exp_output_shape=*/{{3}, {3}},
+                 /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
                  /*type=*/TensorType_UINT8);
 }
 
 TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) {
   Check<uint8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
                  /*input_data=*/{1, 2, 3, 4, 5, 6},
-                 /*expected_output_shape=*/{{2}, {2}, {2}},
-                 /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+                 /*exp_output_shape=*/{{2}, {2}, {2}},
+                 /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
                  /*type=*/TensorType_UINT8);
 }
 
 TEST(UnpackOpTest, Uint8OneOutput) {
   Check<uint8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
                  /*input_data=*/{1, 2, 3, 4, 5, 6},
-                 /*expected_output_shape=*/{{6}},
-                 /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
+                 /*exp_output_shape=*/{{6}},
+                 /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
                  /*type=*/TensorType_UINT8);
 }
 
 TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
   Check<uint8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
                  /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
-                 /*expected_output_shape=*/{{2, 2}, {2, 2}},
-                 /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
+                 /*exp_output_shape=*/{{2, 2}, {2, 2}},
+                 /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
                  /*type=*/TensorType_UINT8);
 }
 
@@ -235,48 +235,48 @@
 TEST(UnpackOpTest, Int8ThreeOutputs) {
   Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
                 /*input_data=*/{1, 2, 3, 4, 5, 6},
-                /*expected_output_shape=*/{{2}, {2}, {2}},
-                /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+                /*exp_output_shape=*/{{2}, {2}, {2}},
+                /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
                 /*type=*/TensorType_INT8);
 }
 
 TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) {
   Check<int8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
                 /*input_data=*/{1, 2, 3, 4, 5, 6},
-                /*expected_output_shape=*/{{3}, {3}},
-                /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+                /*exp_output_shape=*/{{3}, {3}},
+                /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
                 /*type=*/TensorType_INT8);
 }
 
 TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) {
   Check<int8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
                 /*input_data=*/{1, 2, 3, 4, 5, 6},
-                /*expected_output_shape=*/{{3}, {3}},
-                /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+                /*exp_output_shape=*/{{3}, {3}},
+                /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
                 /*type=*/TensorType_INT8);
 }
 
 TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) {
   Check<int8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
                 /*input_data=*/{1, 2, 3, 4, 5, 6},
-                /*expected_output_shape=*/{{2}, {2}, {2}},
-                /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+                /*exp_output_shape=*/{{2}, {2}, {2}},
+                /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
                 /*type=*/TensorType_INT8);
 }
 
 TEST(UnpackOpTest, Int8OneOutput) {
   Check<int8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
                 /*input_data=*/{1, 2, 3, 4, 5, 6},
-                /*expected_output_shape=*/{{6}},
-                /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
+                /*exp_output_shape=*/{{6}},
+                /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
                 /*type=*/TensorType_INT8);
 }
 
 TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
   Check<int8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
                 /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
-                /*expected_output_shape=*/{{2, 2}, {2, 2}},
-                /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
+                /*exp_output_shape=*/{{2, 2}, {2, 2}},
+                /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
                 /*type=*/TensorType_INT8);
 }
 
@@ -288,5 +288,69 @@
                 /*type=*/TensorType_INT8);
 }
 
+// bool tests.
+TEST(UnpackOpTest, BoolThreeOutputs) {
+  Check<bool>(
+      /*axis=*/0, /*input_shape=*/{3, 2},
+      /*input_data=*/{true, false, true, false, true, false},
+      /*exp_output_shape=*/{{2}, {2}, {2}},
+      /*exp_output_data=*/{{true, false}, {true, false}, {true, false}},
+      /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolThreeOutputsAxisOne) {
+  Check<bool>(
+      /*axis=*/1, /*input_shape=*/{3, 2},
+      /*input_data=*/{true, false, true, false, true, false},
+      /*exp_output_shape=*/{{3}, {3}},
+      /*exp_output_data=*/{{true, true, true}, {false, false, false}},
+      /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisOne) {
+  Check<bool>(
+      /*axis=*/-1, /*input_shape=*/{3, 2},
+      /*input_data=*/{true, false, true, false, true, false},
+      /*exp_output_shape=*/{{3}, {3}},
+      /*exp_output_data=*/{{true, true, true}, {false, false, false}},
+      /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisTwo) {
+  Check<bool>(
+      /*axis=*/-2, /*input_shape=*/{3, 2},
+      /*input_data=*/{true, false, true, false, true, false},
+      /*exp_output_shape=*/{{2}, {2}, {2}},
+      /*exp_output_data=*/{{true, false}, {true, false}, {true, false}},
+      /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolOneOutput) {
+  Check<bool>(
+      /*axis=*/0, /*input_shape=*/{1, 6},
+      /*input_data=*/{true, false, true, false, true, false},
+      /*exp_output_shape=*/{{6}},
+      /*exp_output_data=*/{{true, false, true, false, true, false}},
+      /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolThreeDimensionsOutputs) {
+  Check<bool>(
+      /*axis=*/2, /*input_shape=*/{2, 2, 2},
+      /*input_data=*/{true, false, true, false, true, false, true, false},
+      /*exp_output_shape=*/{{2, 2}, {2, 2}},
+      /*exp_output_data=*/
+      {{true, true, true, true}, {false, false, false, false}},
+      /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolVectorToScalar) {
+  Check<bool>(/*axis=*/0, /*input_shape=*/{5},
+              /*input_data=*/{true, false, true, false, true},
+              /*exp_output_shape=*/{{}, {}, {}, {}, {}},
+              /*exp_output_data=*/{{true}, {false}, {true}, {false}, {true}},
+              /*type=*/TensorType_BOOL);
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/testing/op_tests/unpack.py b/tensorflow/lite/testing/op_tests/unpack.py
index c408748..0b59444 100644
--- a/tensorflow/lite/testing/op_tests/unpack.py
+++ b/tensorflow/lite/testing/op_tests/unpack.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
 import tensorflow as tf
 from tensorflow.lite.testing.zip_test_utils import create_tensor_data
 from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
@@ -31,6 +30,7 @@
   test_parameters = [{
       "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
       "axis": [0, 1, 2, 3],
+      "dtype": [tf.int32, tf.bool, tf.float32],
   }]
 
   def get_valid_axis(parameters):
@@ -43,12 +43,15 @@
 
   def build_graph(parameters):
     input_tensor = tf.compat.v1.placeholder(
-        dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
+        dtype=parameters["dtype"],
+        name=("input"),
+        shape=parameters["base_shape"])
     outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
     return [input_tensor], [outs[0]]
 
   def build_inputs(parameters, sess, inputs, outputs):
-    input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
+    input_value = create_tensor_data(
+        parameters["dtype"], shape=parameters["base_shape"])
     return [input_value], sess.run(
         outputs, feed_dict=dict(zip(inputs, [input_value])))
 
diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc
index 241048f..456d877 100644
--- a/tensorflow/lite/toco/tflite/op_version.cc
+++ b/tensorflow/lite/toco/tflite/op_version.cc
@@ -168,6 +168,8 @@
           {{OperatorType::kOneHot, 1}, "1.11.0"},
           {{OperatorType::kCTCBeamSearchDecoder, 1}, "1.11.0"},
           {{OperatorType::kUnpack, 1}, "1.11.0"},
+          {{OperatorType::kUnpack, 2}, "1.14.0"},
+          {{OperatorType::kUnpack, 3}, kPendingReleaseOpVersion},
           {{OperatorType::kLeakyRelu, 1}, "1.13.1"},
           {{OperatorType::kLogistic, 1}, "1.14.0"},
           {{OperatorType::kLogistic, 2}, "1.14.0"},
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index f98a621..f106e4c 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1349,6 +1349,21 @@
     op->num = options.num();
     op->axis = options.axis();
   }
+
+  int GetVersion(const OperatorSignature& op_signature) const override {
+    const string& input_name = op_signature.op->inputs[0];
+    const Array& input_array = op_signature.model->GetArray(input_name);
+    // If the op take int8/uint8 input, it is version 2.
+    if (input_array.data_type == ArrayDataType::kInt8 ||
+        input_array.data_type == ArrayDataType::kUint8) {
+      return 2;
+    }
+    // If the op take bool input, it is version 3.
+    if (input_array.data_type == ArrayDataType::kBool) {
+      return 3;
+    }
+    return 1;
+  }
 };
 
 class LeakyRelu
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index e638840..213e7ff 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -219,6 +219,10 @@
           op_sig.input_types.at(0) == TensorType_UINT8) {
         return 2;
       }
+      // If the op take bool input, it is version 3.
+      if (op_sig.input_types.at(0) == TensorType_BOOL) {
+        return 3;
+      }
       return 1;
 
     case BuiltinOperator_DEQUANTIZE: