[mhlo] Element-types of non-predicate operands to mhlo.select op must match
PiperOrigin-RevId: 465510420
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index b793c1d..7216c6d 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -5428,18 +5428,23 @@
//===----------------------------------------------------------------------===//
LogicalResult SelectOp::verify() {
- // Either, all operands could be the same shape ...
- if (succeeded(verifyCompatibleShapes(getOperandTypes()))) return success();
+ // The operands 'on_true' and 'on_false' should have compatible types, i.e.,
+ // (a) have the same element type, and
+ // (b) have compatible shapes (i.e. the same shape and/or at least one
+ // dynamic shape)
+ if (!compatibleShapeAndElementType(on_true().getType(), on_false().getType()))
+ return emitOpError()
+ << "requires compatible types for non-predicate operands";
- // ... or the predicate could be a scalar and the remaining two operands could
- // be of the same shape.
+ // The predicate, if not-scalar, should have the same shape as the remaining
+ // operands.
auto predTy = pred().getType().dyn_cast<RankedTensorType>();
bool predMayBeScalar = !predTy || predTy.getRank() == 0;
- if (!predMayBeScalar || failed(verifyCompatibleShapes(
- {on_true().getType(), on_false().getType()}))) {
- return emitOpError()
- << "requires the same type for all operands and results";
- }
+ if (predMayBeScalar) return success();
+
+ if (failed(verifyCompatibleShape(pred().getType(), on_true().getType())))
+ return emitOpError() << "requires the same shape for all operands";
+
return success();
}
@@ -5494,18 +5499,7 @@
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SelectOp::Adaptor op(operands, attributes);
auto trueType = op.on_true().getType().cast<TensorType>();
- auto falseType = op.on_true().getType().cast<TensorType>();
-
- // Check for type compatibility in the select op. This requires that the two
- // non-predicate operands:
- // (a) have the same element type
- // (b) have compatible shapes (i.e. the same shape and/or at least one
- // dynamic shape)
- if (trueType.getElementType() != falseType.getElementType() ||
- failed(mlir::verifyCompatibleShape(trueType, falseType))) {
- return emitOptionalError(location, "incompatible operand types: ", trueType,
- " and ", falseType);
- }
+ auto falseType = op.on_false().getType().cast<TensorType>();
// The output shape should be the most general of the operand shapes at each
// dimension.
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
index 9148b67..99aacef 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -1423,7 +1423,7 @@
// -----
func.func @select_bad_pred_shape(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
- // expected-error@+1 {{requires the same type for all operands and results}}
+ // expected-error@+1 {{requires the same shape for all operands}}
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
func.return %0 : tensor<2x3xi32>
}
@@ -1431,29 +1431,36 @@
// -----
func.func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
- // expected-error@+1 {{op requires the same type for all operands and results}}
+ // expected-error@+1 {{requires compatible types for non-predicate operands}}
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
func.return %0 : tensor<2x3xi32>
}
// -----
-func.func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
- // expected-error@+1 {{op requires the same type for all operands and results}}
- %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+func.func @select_when_pred_is_scalar(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+ %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
func.return %0 : tensor<2x3xi32>
}
// -----
-func.func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
- // expected-error@+1 {{requires the same type for all operands and results}}
- %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+func.func @select_element_type_mismatch(%arg0: tensor<i1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+ // expected-error@+1 {{requires compatible types for non-predicate operands}}
+ %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32>
func.return %0 : tensor<2x3xi32>
}
// -----
+func.func @select_element_type_mismatch(%arg0: tensor<i1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xf64>) -> tensor<2x3xf64> {
+ // expected-error@+1 {{requires compatible types for non-predicate operands}}
+ %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xf32>, tensor<2x3xf64>) -> tensor<2x3xf64>
+ func.return %0 : tensor<2x3xf64>
+}
+
+// -----
+
// CHECK-LABEL: func @slice
func.func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>