[mhlo] Restrict return type for mhlo.while condition to "0-rank tensor of type i1"
PiperOrigin-RevId: 426212012
Change-Id: Ideb7e4c400c37d53f824c8a641cc46cacff4b0f8
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index b041a54..aee0523 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -5564,7 +5564,7 @@
<< condReturnOp->getNumOperands();
auto operandType =
condReturnOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
- if (!operandType || // TODO(b/210930774): operandType.getRank() != 0 ||
+ if (!operandType || operandType.getRank() != 0 ||
!operandType.getElementType().isa<IntegerType>() ||
operandType.getElementType().cast<IntegerType>().getWidth() != 1)
return condReturnOp.emitOpError()
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/invalid_while_op.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/invalid_while_op.mlir
index ace2012..e492095 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/invalid_while_op.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/invalid_while_op.mlir
@@ -12,7 +12,8 @@
%2 = arith.constant dense<0> : tensor<i32>
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
- "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+ %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+ "mhlo.return"(%5) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -34,7 +35,8 @@
%2 = arith.constant dense<0> : tensor<i32>
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
- "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+ %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+ "mhlo.return"(%5) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -56,7 +58,8 @@
%2 = arith.constant dense<0> : tensor<i32>
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
- "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+ %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+ "mhlo.return"(%5) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -78,7 +81,8 @@
%2 = arith.constant dense<0> : tensor<i32>
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
- "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+ %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+ "mhlo.return"(%5) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>):
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -89,14 +93,56 @@
// -----
-func @while_with_cond_return_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+func @while_with_cond_return_width_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%cst_0 = arith.constant dense<0> : tensor<1xi32>
%cst_1 = arith.constant dense<[100, 100]> : tensor<2xi32>
%cst_2 = arith.constant dense<1.00> : tensor<1xf32>
%1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
- // expected-error @+1 {{'mhlo.return' op expects a zero-ranked tensor of i1, got 'tensor<2xi32>'}}
- "mhlo.return"(%arg2) : (tensor<2xi32>) -> ()
+ %2 = "mhlo.reshape"(%arg1) : (tensor<1xi32>) -> tensor<i32>
+ // expected-error @+1 {{'mhlo.return' op expects a zero-ranked tensor of i1, got 'tensor<i32>'}}
+ "mhlo.return"(%2) : (tensor<i32>) -> ()
+ }, {
+ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
+ %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
+ %4 = mhlo.add %3, %arg4 : tensor<3xf32>
+ "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> ()
+ }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>)
+ return %1#3: tensor<3xf32>
+}
+
+// -----
+
+func @while_with_cond_return_rank_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+ %cst_0 = arith.constant dense<0> : tensor<1xi32>
+ %cst_1 = arith.constant dense<[100, 100]> : tensor<2xi32>
+ %cst_2 = arith.constant dense<1.00> : tensor<1xf32>
+ %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({
+ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
+ %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
+ %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ // expected-error @+1 {{'mhlo.return' op expects a zero-ranked tensor of i1, got 'tensor<1xi1>'}}
+ "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+ }, {
+ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
+ %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
+ %4 = mhlo.add %3, %arg4 : tensor<3xf32>
+ "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> ()
+ }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>)
+ return %1#3: tensor<3xf32>
+}
+
+// -----
+
+func @while_with_cond_return_type_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+ %cst_0 = arith.constant dense<0> : tensor<1xi32>
+ %cst_1 = arith.constant dense<[100, 100]> : tensor<2xi32>
+ %cst_2 = arith.constant dense<1.00> : tensor<1xf32>
+ %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({
+ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
+ %2 = "mhlo.reshape"(%arg3) : (tensor<1xf32>) -> tensor<f32>
+ // expected-error @+1 {{'mhlo.return' op expects a zero-ranked tensor of i1, got 'tensor<f32>'}}
+ "mhlo.return"(%2) : (tensor<f32>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -117,7 +163,8 @@
%2 = arith.constant dense<0> : tensor<i32>
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
- "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+ %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+ "mhlo.return"(%5) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
// expected-error @+1 {{'mhlo.return' op type mismatch between operand #3 and the enclosing WhileOp returned value: 'tensor<1xf32>' vs 'tensor<3xf32>'}}
@@ -137,8 +184,9 @@
%2 = arith.constant dense<0> : tensor<i32>
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
// expected-error @+1 {{'mhlo.return' op expects a single operand for while condition body return, got 2}}
- "mhlo.return"(%4, %4) : (tensor<1xi1>, tensor<1xi1>) -> ()
+ "mhlo.return"(%5, %5) : (tensor<i1>, tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
@@ -159,7 +207,8 @@
%2 = arith.constant dense<0> : tensor<i32>
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
- "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+ %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+ "mhlo.return"(%5) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
index a056a20..2f03a07 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
@@ -9,11 +9,12 @@
^bb0(%arg1: tensor<1xi64>):
// CHECK: %[[VAL_3:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_2]]) {comparison_direction = "LT", name = "compare.2"} : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
- // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAL_3]] [] : tensor<1xi1> into tensor<i1>
- // CHECK: %[[VAL_4:.*]] = tensor.extract %[[COLLAPSED]][] : tensor<i1>
+ // CHECK: %[[RESHAPE:.*]] = "mhlo.reshape"(%[[VAL_3]]) : (tensor<1xi1>) -> tensor<i1>
+ // CHECK: %[[VAL_4:.*]] = tensor.extract %[[RESHAPE]][] : tensor<i1>
// CHECK: scf.condition(%[[VAL_4]]) %[[VAL_2]] : tensor<1xi64>
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
- "mhlo.return"(%1) : (tensor<1xi1>) -> ()
+ %2 = "mhlo.reshape"(%1) : (tensor<1xi1>) -> tensor<i1>
+ "mhlo.return"(%2) : (tensor<i1>) -> ()
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_5:.*]]: tensor<1xi64>):
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
index d088b4b..b6a7581 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
@@ -2394,7 +2394,8 @@
%2 = arith.constant dense<0> : tensor<i32>
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%4 = "mhlo.compare"(%arg1, %3) {comparison_direction = "LT"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
- "mhlo.return"(%4) : (tensor<1xi1>) -> ()
+ %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor<i1>
+ "mhlo.return"(%5) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>):
%3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32>