Split out compare op handling from XlaBuilder::BinaryOpNoBroadcast

This allows error handling for the compare op to be shared with the MLIR builder.

PiperOrigin-RevId: 305993768
Change-Id: Ia85b9fe6dbf799f0fc4797e2897d19ba658883c3
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index e20f854..de2dfec 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -519,6 +519,7 @@
         ":hlo",
         ":hlo_utils",
         ":type_to_shape",
+        "//tensorflow/compiler/xla:comparison_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/client:xla_builder",
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 0bdf8eb..8bf0362 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -21,6 +21,7 @@
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 
 namespace xla {
@@ -102,12 +103,21 @@
   return MakeXlaOp(op.getResult());
 }
 
-XlaOp MlirHloBuilder::BinaryOpNoBroadcast(
-    HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
-    absl::optional<ComparisonDirection> direction) {
+StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
+                                        XlaOp rhs,
+                                        ComparisonDirection direction) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::CompareOp>(
+      loc_, ty, GetValue(lhs), GetValue(rhs),
+      /*broadcast_dimensions=*/mlir::DenseIntElementsAttr(),
+      builder_.getStringAttr(ComparisonDirectionToString(direction)));
+  return MakeXlaOp(op.getResult());
+}
+
+XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
+                                          XlaOp lhs, XlaOp rhs) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    if (direction.has_value())
-      return Unimplemented("direction attribute not yet supported");
     return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs}, /*attributes=*/{});
   });
 }
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 604db60..8534562 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -97,9 +97,11 @@
       const Shape& shape, XlaOp operand,
       absl::Span<const int64> broadcast_dimensions) override;
 
-  XlaOp BinaryOpNoBroadcast(
-      HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
-      absl::optional<ComparisonDirection> direction) override;
+  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
+                          ComparisonDirection direction) override;
+
+  XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs,
+                            XlaOp rhs) override;
 
   StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                  absl::Span<const XlaOp> operands) override;
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index c1ee15f..9bb5a1e 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -90,6 +90,13 @@
   return %0 : tensor<2xf32>
 }
 
+// CHECK-LABEL: func @greater
+func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT:  "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
+  %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  return %0: tensor<2xi1>
+}
+
 // TODO(hinsu): Add a test with variant type once one of the ops supporting
 // the type is whitelisted. It should be rejected with unsupported type remark.
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 7ae18eb..bf7bfe8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -78,7 +78,8 @@
   // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
   // all tf2xla kernels.
   return isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op) || isa<TF::CastOp>(op) ||
-         isa<TF::InvOp>(op) || isa<TF::SelectV2Op>(op);
+         isa<TF::GreaterOp>(op) || isa<TF::InvOp>(op) ||
+         isa<TF::SelectV2Op>(op);
 }
 
 static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 3dd2b93..807cbe9 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -30,6 +30,7 @@
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -561,33 +562,40 @@
                           AddBroadcastSequence(shape, updated_rhs));
     }
 
-    return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs,
-                               direction);
-  });
-}
-
-XlaOp XlaBuilder::BinaryOpNoBroadcast(
-    HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
-    absl::optional<ComparisonDirection> direction) {
-  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-    *instr.mutable_shape() = shape.ToProto();
     if (binop == HloOpcode::kCompare) {
       if (!direction.has_value()) {
         return InvalidArgument(
             "kCompare expects a ComparisonDirection, but none provided.");
       }
-      instr.set_comparison_direction(ComparisonDirectionToString(*direction));
-    } else if (direction.has_value()) {
+      return Compare(shape, updated_lhs, updated_rhs, *direction);
+    }
+
+    if (direction.has_value()) {
       return InvalidArgument(
           "A comparison direction is provided for a non-compare opcode: %s.",
           HloOpcodeString(binop));
     }
+    return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
+  });
+}
 
+XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
+                                      XlaOp lhs, XlaOp rhs) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    *instr.mutable_shape() = shape.ToProto();
     return AddInstruction(std::move(instr), binop, {lhs, rhs});
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
+                                    ComparisonDirection direction) {
+  HloInstructionProto instr;
+  instr.set_comparison_direction(ComparisonDirectionToString(direction));
+  *instr.mutable_shape() = shape.ToProto();
+  return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
+}
+
 XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     XlaOp updated_lhs = lhs;
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 15411ed..06fc518 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -662,11 +662,14 @@
                  absl::Span<const int64> broadcast_dimensions,
                  absl::optional<ComparisonDirection> direction = absl::nullopt);
 
+  // Internal helper method for binary op compare without broadcast dimensions.
+  virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
+                                  ComparisonDirection direction);
+
   // Internal helper method that does the building for an arbitrary binary op
   // with same ranked operands that doesn't broadcast.
-  virtual XlaOp BinaryOpNoBroadcast(
-      HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
-      absl::optional<ComparisonDirection> direction);
+  virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
+                                    XlaOp lhs, XlaOp rhs);
 
   // Internal helper method that does the building for an arbitrary ternary op.
   XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);