Allow both Enter(data, control) as well as Enter(data).

PiperOrigin-RevId: 315584317
Change-Id: I50e8651ccbf0957a7edf1ef958aa398235867292
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index 9daebc2..3403651 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -811,11 +811,13 @@
   // fully qualified) or a short form with a single type (in which case the data
   // input and the outputs are all using this type).
   if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
-    if (type.getNumInputs() != 1)
-      return parser.emitError(parser.getNameLoc())
-             << " expects a single data type";
-    result.types.assign(type.getResults().begin(), type.getResults().end());
-    types.assign(type.getInputs().begin(), type.getInputs().end());
+    // One data input, and any number of control inputs.
+    if (type.getNumInputs() >= 1) {
+      result.types.assign(type.getResults().begin(), type.getResults().end());
+      types.assign(type.getInputs().begin(), type.getInputs().end());
+    } else {
+      return parser.emitError(parser.getNameLoc()) << " expects a data input";
+    }
   } else {
     Type control_type = ControlType::get(context);
     types.append(op_infos.size() - 1, control_type);
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
index c048db5..27b8472 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
@@ -416,6 +416,17 @@
   return %0 : tensor<*xf32>
 }
 
+// CHECK-LABEL: func @enter_control_longform(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<i1>) -> tensor<*xf32> {
+func @enter_control_longform(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
+  %0 = tf_executor.graph {
+    %1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
+// CHECK: tf_executor.Enter %{{.*}}, %{{.*}}, %{{.*}} frame "some/frame" : tensor<*xf32>
+    %res:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : (tensor<*xf32>, !tf_executor.control, !tf_executor.control) -> (tensor<*xf32>, !tf_executor.control)
+    tf_executor.fetch %res#0 : tensor<*xf32>
+  }
+  return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL: func @nextiteration(%{{.*}}: tensor<*xf32>, %{{.*}}: i1) -> tensor<*xf32> {
 func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
   %0 = tf_executor.graph {