[mhlo] Element-types of non-predicate operands to mhlo.select op must match

PiperOrigin-RevId: 465510420
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index b793c1d..7216c6d 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -5428,18 +5428,23 @@
 //===----------------------------------------------------------------------===//
 
 LogicalResult SelectOp::verify() {
-  // Either, all operands could be the same shape ...
-  if (succeeded(verifyCompatibleShapes(getOperandTypes()))) return success();
+  // The operands 'on_true' and 'on_false' should have compatible types, i.e.,
+  //   (a) have the same element type, and
+  //   (b) have compatible shapes (i.e. the same shape and/or at least one
+  //       dynamic shape)
+  if (!compatibleShapeAndElementType(on_true().getType(), on_false().getType()))
+    return emitOpError()
+           << "requires compatible types for non-predicate operands";
 
-  // ... or the predicate could be a scalar and the remaining two operands could
-  // be of the same shape.
+  // The predicate, if not-scalar, should have the same shape as the remaining
+  // operands.
   auto predTy = pred().getType().dyn_cast<RankedTensorType>();
   bool predMayBeScalar = !predTy || predTy.getRank() == 0;
-  if (!predMayBeScalar || failed(verifyCompatibleShapes(
-                              {on_true().getType(), on_false().getType()}))) {
-    return emitOpError()
-           << "requires the same type for all operands and results";
-  }
+  if (predMayBeScalar) return success();
+
+  if (failed(verifyCompatibleShape(pred().getType(), on_true().getType())))
+    return emitOpError() << "requires the same shape for all operands";
+
   return success();
 }
 
@@ -5494,18 +5499,7 @@
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
   SelectOp::Adaptor op(operands, attributes);
   auto trueType = op.on_true().getType().cast<TensorType>();
-  auto falseType = op.on_true().getType().cast<TensorType>();
-
-  // Check for type compatibility in the select op. This requires that the two
-  // non-predicate operands:
-  //   (a) have the same element type
-  //   (b) have compatible shapes (i.e. the same shape and/or at least one
-  //       dynamic shape)
-  if (trueType.getElementType() != falseType.getElementType() ||
-      failed(mlir::verifyCompatibleShape(trueType, falseType))) {
-    return emitOptionalError(location, "incompatible operand types: ", trueType,
-                             " and ", falseType);
-  }
+  auto falseType = op.on_false().getType().cast<TensorType>();
 
   // The output shape should be the most general of the operand shapes at each
   // dimension.
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
index 9148b67..99aacef 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -1423,7 +1423,7 @@
 // -----
 
 func.func @select_bad_pred_shape(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // expected-error@+1 {{requires the same type for all operands and results}}
+  // expected-error@+1 {{requires the same shape for all operands}}
   %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   func.return %0 : tensor<2x3xi32>
 }
@@ -1431,29 +1431,36 @@
 // -----
 
 func.func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // expected-error@+1 {{op requires the same type for all operands and results}}
+  // expected-error@+1 {{requires compatible types for non-predicate operands}}
   %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   func.return %0 : tensor<2x3xi32>
 }
 
 // -----
 
-func.func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
-  // expected-error@+1 {{op requires the same type for all operands and results}}
-  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+func.func @select_when_pred_is_scalar(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   func.return %0 : tensor<2x3xi32>
 }
 
 // -----
 
-func.func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // expected-error@+1 {{requires the same type for all operands and results}}
-  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+func.func @select_element_type_mismatch(%arg0: tensor<i1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // expected-error@+1 {{requires compatible types for non-predicate operands}}
+  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   func.return %0 : tensor<2x3xi32>
 }
 
 // -----
 
+func.func @select_element_type_mismatch(%arg0: tensor<i1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf64>) -> tensor<2x3xf64> {
+  // expected-error@+1 {{requires compatible types for non-predicate operands}}
+  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xf32>, tensor<2x3xf64>) -> tensor<2x3xf64>
+  func.return %0 : tensor<2x3xf64>
+}
+
+// -----
+
 // CHECK-LABEL: func @slice
 func.func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
   %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>