Fix parser bug for tf_executor.Switch

Fixes the following error for tf_executor.Switch with one or more valid control
inputs: `custom op 'tf_executor.Switch' expected 2 operands`.
PiperOrigin-RevId: 306755642
Change-Id: I960bf399bfa4638090d623b8ca2b5143c707ed42
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index 8d20281..d5ecbf3 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -474,7 +474,7 @@
 ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 2> op_infos;
   SmallVector<Type, 1> types;
-  if (parser.parseOperandList(op_infos, 2) || parser.parseColonTypeList(types))
+  if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
     return failure();
   if (types.size() != 1)
     return parser.emitError(parser.getNameLoc())
@@ -486,12 +486,15 @@
   // type).
   if (types.front().isa<FunctionType>()) {
     FunctionType type = types.front().cast<FunctionType>();
-    if (type.getNumInputs() != 2)
+    if (type.getNumInputs() < 2)
       return parser.emitError(parser.getNameLoc())
              << " expects a single data type and a predicate";
     result.types.assign(type.getResults().begin(), type.getResults().end());
     types.assign(type.getInputs().begin(), type.getInputs().end());
   } else {
+    if (op_infos.size() < 2)
+      return parser.emitError(parser.getNameLoc())
+             << " expects a single data type and a predicate";
     Type control_type = ControlType::get(parser.getBuilder().getContext());
     result.types.append(2, types[0]);
     result.types.push_back(control_type);
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
index 6282ab1..c048db5 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
@@ -187,6 +187,26 @@
   return %result : tensor<*xf32>
 }
 
+// CHECK-LABEL: func @switch_with_control_inputs(
+func @switch_with_control_inputs(%arg0: tensor<i1>, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor<i1> {
+  %result = tf_executor.graph {
+// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor<i1>
+    %1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : tensor<i1>
+    tf_executor.fetch %1#0 : tensor<i1>
+  }
+  return %result : tensor<i1>
+}
+
+// CHECK-LABEL: func @switch_with_control_inputs_functional(
+func @switch_with_control_inputs_functional(%arg0: tensor<i1>, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor<i1> {
+  %result = tf_executor.graph {
+// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor<i1>
+    %1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : (tensor<i1>, tensor<i1>, !tf_executor.control, !tf_executor.control) -> (tensor<i1>, tensor<i1>, !tf_executor.control)
+    tf_executor.fetch %1#0 : tensor<i1>
+  }
+  return %result : tensor<i1>
+}
+
 // CHECK-LABEL: func @switchN(
 func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
   %fetches = tf_executor.graph {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
index a249090..1fdc99d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
@@ -333,7 +333,7 @@
 
 // -----
 
-// Check that a switch always takes two arguments.
+// Check that a switch always needs at least two arguments.
 func @invalid_switch(%arg0: tensor<*xf32>) {
   tf_executor.graph {
     %true, %false, %ctlSwitch = "tf_executor.Switch"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
@@ -344,6 +344,17 @@
 
 // -----
 
+// Check that a switch always needs at least two arguments.
+func @invalid_switch(%arg0: tensor<*xf32>) {
+  tf_executor.graph {
+    %true, %false, %ctlSwitch = tf_executor.Switch %arg0 : tensor<*xf32>
+// expected-error@-1 {{custom op 'tf_executor.Switch'  expects a single data type and a predicate}}
+  }
+  return
+}
+
+// -----
+
 // Check that a switch second argument must be a valid predicate (i1).
 func @invalid_switch(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
   %result = tf_executor.graph {