[MLIR][KernelGen] Implement InferShapedTypeOpInterface for `mhlo.compare/select`

PiperOrigin-RevId: 332227340
Change-Id: I7d90e510287d523765d8e59231f78d7f73fb3b79
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 22afeaa..351e8bd 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -678,9 +678,10 @@
   let hasCanonicalizer = 1;
 }
 
-def HLO_CompareOp: HLO_Op<"compare",
-      [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>,
-      BASE_HLO_CompareOp {
+def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
+    SameOperandsAndResultShape,
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+    ["reifyReturnTypeShapes"]>]>, BASE_HLO_CompareOp {
   let arguments = (ins
     HLO_Tensor:$lhs,
     HLO_Tensor:$rhs,
@@ -1152,7 +1153,10 @@
 }
 
 // TODO(jpienaar): Add broadcastable trait.
-def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]>, BASE_HLO_SelectOp {
+def HLO_SelectOp: HLO_Op<"select", [NoSideEffect,
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+    ["reifyReturnTypeShapes"]>, DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    ]>, BASE_HLO_SelectOp {
   let arguments = (ins
     HLO_PredTensor:$pred,
     HLO_Tensor:$on_true,
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index fcd91f8..6711a91 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1678,6 +1678,20 @@
   return success();
 }
 
+LogicalResult SelectOp::inferReturnTypeComponents(
+    mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
+    mlir::DictionaryAttr, mlir::RegionRange,
+    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
+  // TODO(b/168772852)
+  return failure();
+}
+
+LogicalResult SelectOp::reifyReturnTypeShapes(
+    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
+  return deriveShapeFromFirstOperand(&builder, getOperation(),
+                                     &reifiedReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // PadOp
 //===----------------------------------------------------------------------===//
@@ -2473,9 +2487,22 @@
   build(builder, result, new_type, lhs, rhs, comparison_direction);
 }
 
+LogicalResult CompareOp::inferReturnTypeComponents(
+    mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
+    mlir::DictionaryAttr, mlir::RegionRange,
+    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
+  // TODO(b/168772852)
+  return failure();
+}
+
+LogicalResult CompareOp::reifyReturnTypeShapes(
+    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
+  return deriveShapeFromFirstOperand(&builder, getOperation(),
+                                     &reifiedReturnShapes);
+}
+
 }  // namespace mhlo
 }  // namespace mlir
-
 #define GET_OP_CLASSES
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
 
diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir
index d226c92..0738459 100644
--- a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @broadcast_add
 // Note that all broadcast_ops are expanded from the same template, so
diff --git a/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir
new file mode 100644
index 0000000..d626f52
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: @select
+// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
+func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
+    -> tensor<2xi64> {
+  // CHECK: %[[C2:.*]] = constant 2 : i64
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
+  // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
+  // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
+  // CHECK: return %[[SHAPE]] : tensor<2xi64>
+  %0 = "mhlo.select"(%pred, %a, %b)
+      : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+  %1 = "mhlo_test.reify_return_type_shapes"(%0)
+      : (tensor<2x?xf32>) -> tensor<2xi64>
+  return %1 : tensor<2xi64>
+}
+
+// -----
+// CHECK-LABEL: @compare
+// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
+func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
+  // CHECK: %[[C2:.*]] = constant 2 : i64
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
+  // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
+  // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
+  // CHECK: return %[[SHAPE]] : tensor<2xi64>
+  %0 = "mhlo.compare"(%a, %b) { comparison_direction = "NE" }
+      : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
+  %1 = "mhlo_test.reify_return_type_shapes"(%0)
+      : (tensor<2x?xi1>) -> tensor<2xi64>
+  return %1 : tensor<2xi64>
+}
+