Split out compare op handling from XlaBuilder::BinaryOpNoBroadcast

This allows the error handling for the compare op (validating the optional
ComparisonDirection) to be shared with the MLIR builder, which previously
reported the direction attribute as unimplemented.

PiperOrigin-RevId: 305993768
Change-Id: Ia85b9fe6dbf799f0fc4797e2897d19ba658883c3
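
For context, the dispatch that results in XlaBuilder::BinaryOp, condensed from
the xla_builder.cc hunk below (error messages abbreviated here; the patch has
the exact text):

    // After operand broadcasting in XlaBuilder::BinaryOp:
    if (binop == HloOpcode::kCompare) {
      if (!direction.has_value()) {
        return InvalidArgument("kCompare expects a ComparisonDirection ...");
      }
      return Compare(shape, updated_lhs, updated_rhs, *direction);
    }
    if (direction.has_value()) {
      return InvalidArgument("... direction ... for a non-compare opcode ...");
    }
    return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);

Compare and BinaryOpNoBroadcast are both virtual, so MlirHloBuilder overrides
them to emit MLIR ops while the direction checks above are shared by both
builders.
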
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index e20f854..de2dfec 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -519,6 +519,7 @@
":hlo",
":hlo_utils",
":type_to_shape",
+ "//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/client:xla_builder",
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 0bdf8eb..8bf0362 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -21,6 +21,7 @@
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
namespace xla {
@@ -102,12 +103,21 @@
return MakeXlaOp(op.getResult());
}

-XlaOp MlirHloBuilder::BinaryOpNoBroadcast(
- HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
- absl::optional<ComparisonDirection> direction) {
+StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
+ XlaOp rhs,
+ ComparisonDirection direction) {
+ TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+ shape, builder_));
+ auto op = builder_.create<mlir::xla_hlo::CompareOp>(
+ loc_, ty, GetValue(lhs), GetValue(rhs),
+ /*broadcast_dimensions=*/mlir::DenseIntElementsAttr(),
+ builder_.getStringAttr(ComparisonDirectionToString(direction)));
+ return MakeXlaOp(op.getResult());
+}
+
+XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
+ XlaOp lhs, XlaOp rhs) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- if (direction.has_value())
- return Unimplemented("direction attribute not yet supported");
return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs}, /*attributes=*/{});
});
}
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 604db60..8534562 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -97,9 +97,11 @@
const Shape& shape, XlaOp operand,
absl::Span<const int64> broadcast_dimensions) override;

- XlaOp BinaryOpNoBroadcast(
- HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
- absl::optional<ComparisonDirection> direction) override;
+ StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
+ ComparisonDirection direction) override;
+
+ XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs,
+ XlaOp rhs) override;

StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
absl::Span<const XlaOp> operands) override;
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index c1ee15f..9bb5a1e 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -90,6 +90,13 @@
return %0 : tensor<2xf32>
}

+// CHECK-LABEL: func @greater
+func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
+ // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
+ %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+ return %0: tensor<2xi1>
+}
+
// TODO(hinsu): Add a test with variant type once one of the ops supporting
// the type is whitelisted. It should be rejected with unsupported type remark.

diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 7ae18eb..bf7bfe8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -78,7 +78,8 @@
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
// all tf2xla kernels.
return isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op) || isa<TF::CastOp>(op) ||
- isa<TF::InvOp>(op) || isa<TF::SelectV2Op>(op);
+ isa<TF::GreaterOp>(op) || isa<TF::InvOp>(op) ||
+ isa<TF::SelectV2Op>(op);
}

static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 3dd2b93..807cbe9 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -30,6 +30,7 @@
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -561,33 +562,40 @@
AddBroadcastSequence(shape, updated_rhs));
}

- return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs,
- direction);
- });
-}
-
-XlaOp XlaBuilder::BinaryOpNoBroadcast(
- HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
- absl::optional<ComparisonDirection> direction) {
- return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
- *instr.mutable_shape() = shape.ToProto();
if (binop == HloOpcode::kCompare) {
if (!direction.has_value()) {
return InvalidArgument(
"kCompare expects a ComparisonDirection, but none provided.");
}
- instr.set_comparison_direction(ComparisonDirectionToString(*direction));
- } else if (direction.has_value()) {
+ return Compare(shape, updated_lhs, updated_rhs, *direction);
+ }
+
+ if (direction.has_value()) {
return InvalidArgument(
"A comparison direction is provided for a non-compare opcode: %s.",
HloOpcodeString(binop));
}
+ return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
+ });
+}
+
+XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
+ XlaOp lhs, XlaOp rhs) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), binop, {lhs, rhs});
});
}

+StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
+ ComparisonDirection direction) {
+ HloInstructionProto instr;
+ instr.set_comparison_direction(ComparisonDirectionToString(direction));
+ *instr.mutable_shape() = shape.ToProto();
+ return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
+}
+
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
XlaOp updated_lhs = lhs;
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 15411ed..06fc518 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -662,11 +662,14 @@
absl::Span<const int64> broadcast_dimensions,
absl::optional<ComparisonDirection> direction = absl::nullopt);

+ // Internal helper method for binary op compare without broadcast dimensions.
+ virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
+ ComparisonDirection direction);
+
// Internal helper method that does the building for an arbitrary binary op
// with same ranked operands that doesn't broadcast.
- virtual XlaOp BinaryOpNoBroadcast(
- HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
- absl::optional<ComparisonDirection> direction);
+ virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
+ XlaOp lhs, XlaOp rhs);

// Internal helper method that does the building for an arbitrary ternary op.
XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);
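
For reference, a minimal client-side sketch of the shared path. This toy
program is illustrative only and is not part of the change; Parameter, Gt,
ShapeUtil::MakeShape, and XlaBuilder::Build are existing XLA client APIs:

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Builds pred[2] = (x > y). Gt() reaches XlaBuilder::BinaryOp with
    // ComparisonDirection::kGt, which now dispatches to the Compare() hook.
    xla::XlaComputation BuildGreater() {
      xla::XlaBuilder b("greater_example");
      auto x = xla::Parameter(&b, /*parameter_number=*/0,
                              xla::ShapeUtil::MakeShape(xla::S32, {2}), "x");
      auto y = xla::Parameter(&b, /*parameter_number=*/1,
                              xla::ShapeUtil::MakeShape(xla::S32, {2}), "y");
      xla::Gt(x, y);  // becomes an HLO kCompare with direction GT
      return b.Build().ValueOrDie();
    }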