Added comparison (gt, ge, lt, le) patterns for TF to XLA.

PiperOrigin-RevId: 269018428
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index e6efdc8..a16e5db 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -245,7 +245,7 @@
 }
 
 def HLO_CompareOp: HLO_Op<"compare",
-      [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_CompareOp {
+      [NoSideEffect, SameOperandsElementType]>, BASE_HLO_CompareOp {
   let arguments = (ins
     HLO_Tensor:$lhs,
     HLO_Tensor:$rhs,
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 4af6726..a347f61 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -141,6 +141,67 @@
   return %0: tensor<1x2xi32>
 }
 
+
+//===----------------------------------------------------------------------===//
+// Compare op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @greater
+func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT:  "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
+  %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  return %0: tensor<2xi1>
+}
+
+// CHECK-LABEL: func @broadcast_greater
+func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
+  // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"}
+  %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  return %0: tensor<1x2xi1>
+}
+
+// CHECK-LABEL: func @greater_equal
+func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT:  "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
+  %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  return %0: tensor<2xi1>
+}
+
+// CHECK-LABEL: func @broadcast_greater_equal
+func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
+  // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"}
+  %0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  return %0: tensor<1x2xi1>
+}
+
+// CHECK-LABEL: func @less
+func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT:  "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
+  %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  return %0: tensor<2xi1>
+}
+
+// CHECK-LABEL: func @broadcast_less
+func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
+  // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"}
+  %0 = "tf.Less"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  return %0: tensor<1x2xi1>
+}
+
+// CHECK-LABEL: func @less_equal
+func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT:  "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
+  %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  return %0: tensor<2xi1>
+}
+
+// CHECK-LABEL: func @broadcast_less_equal
+func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
+  // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"}
+  %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+  return %0: tensor<1x2xi1>
+}
+
 //===----------------------------------------------------------------------===//
 // Concat op legalizations.
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index fe930e6..cdc09e4 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -88,6 +88,19 @@
   def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
 
 //===----------------------------------------------------------------------===//
+// Compare op patterns.
+//===----------------------------------------------------------------------===//
+
+class DirectComparePat<Op FromOp, StrEnumAttrCase direction>
+  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
+        (HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction)>;
+
+def : DirectComparePat<TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT>;
+def : DirectComparePat<TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE>;
+def : DirectComparePat<TF_LessOp, HLO_COMPARISON_DIRECTION_LT>;
+def : DirectComparePat<TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE>;
+
+//===----------------------------------------------------------------------===//
 // Concat op patterns.
 //===----------------------------------------------------------------------===//