[mhlo] Restrict the mhlo.while condition's return type to a zero-ranked tensor of i1

PiperOrigin-RevId: 426212012
Change-Id: Ideb7e4c400c37d53f824c8a641cc46cacff4b0f8
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index b041a54..aee0523 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -5564,7 +5564,7 @@
              << condReturnOp->getNumOperands();
     auto operandType =
         condReturnOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
-    if (!operandType ||  // TODO(b/210930774): operandType.getRank() != 0 ||
+    if (!operandType || operandType.getRank() != 0 ||
         !operandType.getElementType().isa<IntegerType>() ||
         operandType.getElementType().cast<IntegerType>().getWidth() != 1)
       return condReturnOp.emitOpError()
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/invalid_while_op.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/invalid_while_op.mlir
index ace2012..e492095 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/invalid_while_op.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/invalid_while_op.mlir
@@ -12,7 +12,8 @@
     %2 = arith.constant dense<0> : tensor<i32>
     %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
     %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+    %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+    "mhlo.return"(%5) : (tensor<i1>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
     %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -34,7 +35,8 @@
     %2 = arith.constant dense<0> : tensor<i32>
     %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
     %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+    %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+    "mhlo.return"(%5) : (tensor<i1>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
     %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -56,7 +58,8 @@
     %2 = arith.constant dense<0> : tensor<i32>
     %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
     %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+    %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+    "mhlo.return"(%5) : (tensor<i1>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
     %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -78,7 +81,8 @@
     %2 = arith.constant dense<0> : tensor<i32>
     %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
     %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+    %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+    "mhlo.return"(%5) : (tensor<i1>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>):
     %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -89,14 +93,56 @@
 
 // -----
 
-func @while_with_cond_return_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+func @while_with_cond_return_width_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
   %cst_0 = arith.constant dense<0> : tensor<1xi32>
   %cst_1 = arith.constant dense<[100, 100]> : tensor<2xi32>
   %cst_2 = arith.constant dense<1.00> : tensor<1xf32>
   %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
-  // expected-error @+1 {{'mhlo.return' op expects a zero-ranked tensor of i1, got 'tensor<2xi32>'}}
-    "mhlo.return"(%arg2) : (tensor<2xi32>) -> ()
+    %2 = "mhlo.reshape"(%arg1) : (tensor<1xi32>) -> tensor<i32>
+    // expected-error @+1 {{'mhlo.return' op expects a zero-ranked tensor of i1, got 'tensor<i32>'}}
+    "mhlo.return"(%2) : (tensor<i32>) -> ()
+  },  {
+  ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
+    %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
+    %4 = mhlo.add %3, %arg4 : tensor<3xf32>
+    "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> ()
+  }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>)
+  return %1#3: tensor<3xf32>
+}
+
+// -----
+
+func @while_with_cond_return_rank_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+  %cst_0 = arith.constant dense<0> : tensor<1xi32>
+  %cst_1 = arith.constant dense<[100, 100]> : tensor<2xi32>
+  %cst_2 = arith.constant dense<1.00> : tensor<1xf32>
+  %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({
+  ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
+    %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
+    %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+    // expected-error @+1 {{'mhlo.return' op expects a zero-ranked tensor of i1, got 'tensor<1xi1>'}}
+    "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+  },  {
+  ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
+    %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
+    %4 = mhlo.add %3, %arg4 : tensor<3xf32>
+    "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> ()
+  }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>)
+  return %1#3: tensor<3xf32>
+}
+
+// -----
+
+func @while_with_cond_return_type_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+  %cst_0 = arith.constant dense<0> : tensor<1xi32>
+  %cst_1 = arith.constant dense<[100, 100]> : tensor<2xi32>
+  %cst_2 = arith.constant dense<1.00> : tensor<1xf32>
+  %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({
+  ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
+    %2 = "mhlo.reshape"(%arg3) : (tensor<1xf32>) -> tensor<f32>
+    // expected-error @+1 {{'mhlo.return' op expects a zero-ranked tensor of i1, got 'tensor<f32>'}}
+    "mhlo.return"(%2) : (tensor<f32>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
     %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -117,7 +163,8 @@
     %2 = arith.constant dense<0> : tensor<i32>
     %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
     %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+    %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+    "mhlo.return"(%5) : (tensor<i1>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
     // expected-error @+1 {{'mhlo.return' op type mismatch between operand #3 and the enclosing WhileOp returned value: 'tensor<1xf32>' vs 'tensor<3xf32>'}}
@@ -137,8 +184,9 @@
     %2 = arith.constant dense<0> : tensor<i32>
     %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
     %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+    %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
   // expected-error @+1 {{'mhlo.return' op expects a single operand for while condition body return, got 2}}
-    "mhlo.return"(%4, %4) : (tensor<1xi1>, tensor<1xi1>) -> ()
+    "mhlo.return"(%5, %5) : (tensor<i1>, tensor<i1>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
     %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -159,7 +207,8 @@
     %2 = arith.constant dense<0> : tensor<i32>
     %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
     %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+    %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+    "mhlo.return"(%5) : (tensor<i1>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
     %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
index a056a20..2f03a07 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
@@ -9,11 +9,12 @@
   ^bb0(%arg1: tensor<1xi64>):
 
     // CHECK: %[[VAL_3:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_2]]) {comparison_direction = "LT", name = "compare.2"} : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
-    // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAL_3]] [] : tensor<1xi1> into tensor<i1>
-    // CHECK: %[[VAL_4:.*]] = tensor.extract %[[COLLAPSED]][] : tensor<i1>
+    // CHECK: %[[RESHAPE:.*]] = "mhlo.reshape"(%[[VAL_3]]) : (tensor<1xi1>) -> tensor<i1>
+    // CHECK: %[[VAL_4:.*]] = tensor.extract %[[RESHAPE]][] : tensor<i1>
     // CHECK: scf.condition(%[[VAL_4]]) %[[VAL_2]] : tensor<1xi64>
     %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
-    "mhlo.return"(%1) : (tensor<1xi1>) -> ()
+    %2 = "mhlo.reshape"(%1) : (tensor<1xi1>) -> tensor<i1>
+    "mhlo.return"(%2) : (tensor<i1>) -> ()
 
   // CHECK: } do {
   // CHECK: ^bb0(%[[VAL_5:.*]]: tensor<1xi64>):
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
index d088b4b..b6a7581 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
@@ -2394,7 +2394,8 @@
     %2 = arith.constant dense<0> : tensor<i32>
     %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
     %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+    %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+    "mhlo.return"(%5) : (tensor<i1>) -> ()
   },  {
   ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
     %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>