[XLA] Verify the comparison type for comparisons
PiperOrigin-RevId: 343577367
Change-Id: Ic39711697188f4927d0b23eb7f4b6cda6a14630a
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 28f7682..171afa4 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -234,6 +234,7 @@
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib",
+ "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 9d2d190..3350653 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -34,6 +34,7 @@
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -46,6 +47,7 @@
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
@@ -3350,6 +3352,11 @@
}
if (!need_rewrite) {
+ if (opcode == HloOpcode::kCompare) {
+ CHECK(!instr_proto->comparison_type().empty());
+ new_instr->set_comparison_type(
+ ComparisonTypeToString(Comparison::DefaultComparisonType(PRED)));
+ }
*new_instr->mutable_name() =
GetFullName(instr_proto->opcode(), kNameSeparator, id);
return Status::OK();
@@ -4009,11 +4016,26 @@
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
}
+static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
+ absl::Span<const int64> broadcast_dimensions,
+ ComparisonDirection comparison_direction) {
+ auto b = lhs.builder();
+ return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
+ auto operand_element_type = operand_shape.element_type();
+ auto compare_type =
+ primitive_util::IsFloatingPointType(operand_element_type)
+ ? Comparison::Type::kFloatTotalOrder
+ : Comparison::DefaultComparisonType(operand_element_type);
+ return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
+ compare_type);
+ });
+}
+
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
- auto compare_type = Comparison::Type::kFloatTotalOrder;
- return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
- compare_type);
+ return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+ ComparisonDirection::kEq);
}
XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
@@ -4023,9 +4045,8 @@
XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
- auto compare_type = Comparison::Type::kFloatTotalOrder;
- return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
- compare_type);
+ return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+ ComparisonDirection::kNe);
}
XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
@@ -4035,9 +4056,8 @@
XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
- auto compare_type = Comparison::Type::kFloatTotalOrder;
- return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
- compare_type);
+ return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+ ComparisonDirection::kGe);
}
XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
@@ -4047,9 +4067,8 @@
XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
- auto compare_type = Comparison::Type::kFloatTotalOrder;
- return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
- compare_type);
+ return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+ ComparisonDirection::kGt);
}
XlaOp Le(const XlaOp lhs, const XlaOp rhs,
@@ -4059,10 +4078,10 @@
XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
- auto compare_type = Comparison::Type::kFloatTotalOrder;
- return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
- compare_type);
+ return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+ ComparisonDirection::kLe);
}
+
XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
@@ -4070,8 +4089,8 @@
XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
- return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
- Comparison::Type::kFloatTotalOrder);
+ return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+ ComparisonDirection::kLt);
}
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index b1034bf..5fe89e8 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3721,6 +3721,7 @@
":hlo_casting_utils",
":hlo_pass",
":shape_inference",
+ "//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index f2ea03f..e43f68f 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -25,6 +25,7 @@
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 8c66d00..84e4fe6 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -19,6 +19,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -1746,6 +1747,31 @@
ShapeUtil::HumanString(operand_shape));
}
}
+ if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
+ const Shape& operand_shape = comparison->operand(1)->shape();
+ PrimitiveType operand_element_type = operand_shape.element_type();
+ Comparison::Type default_comparison_type =
+ Comparison::DefaultComparisonType(operand_element_type);
+ if (primitive_util::IsFloatingPointType(operand_element_type)) {
+ if (comparison->type() != Comparison::Type::kFloat &&
+ comparison->type() != Comparison::Type::kFloatTotalOrder) {
+ return FailedPrecondition(
+ "Expected comparison type %s or %s.\n"
+ "actual: %s\noperand: %s\n",
+ ComparisonTypeToString(Comparison::Type::kFloat),
+ ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
+ ComparisonTypeToString(comparison->type()),
+ ShapeUtil::HumanString(operand_shape));
+ }
+ } else if (comparison->type() != default_comparison_type) {
+ return FailedPrecondition(
+ "Expected comparison type %s.\n"
+ "actual: %s\noperand: %s\n",
+ ComparisonTypeToString(default_comparison_type),
+ ComparisonTypeToString(comparison->type()),
+ ShapeUtil::HumanString(operand_shape));
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 0df3016..c6c09e3 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -1220,5 +1220,77 @@
"needs to be collective-permute-start, found tuple"));
}
+TEST_F(HloVerifierTest, ComparisonTypeFloat) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngOperandElementTypesNotMatch {
+ p0 = f32[] parameter(0)
+ ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnUnverifiedModule(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypeSigned) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngOperandElementTypesNotMatch {
+ p0 = s32[] parameter(0)
+ ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnUnverifiedModule(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Expected comparison type SIGNED"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngOperandElementTypesNotMatch {
+ p0 = u32[] parameter(0)
+ ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnUnverifiedModule(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Expected comparison type UNSIGNED"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypePred) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngOperandElementTypesNotMatch {
+ p0 = pred[] parameter(0)
+ ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnUnverifiedModule(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Expected comparison type UNSIGNED"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index fe27a8c..69916f6 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -554,7 +554,7 @@
abs.129 = f32[4]{0} abs(subtract.126)
constant.130 = f32[] constant(inf)
broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={}
- compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED
+ compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ
not.133 = pred[4]{0} not(compare.132)
and.134 = pred[4]{0} and(not.128, not.133)
add.135 = f32[4]{0} add(add.124, add.89)
@@ -577,7 +577,7 @@
abs.219 = f32[4]{0} abs(subtract.216)
constant.220 = f32[] constant(inf)
broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={}
- compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED
+ compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ
not.223 = pred[4]{0} not(compare.222)
and.224 = pred[4]{0} and(not.218, not.223)
add.225 = f32[4]{0} add(add.214, add.179)
@@ -600,7 +600,7 @@
abs.309 = f32[4]{0} abs(subtract.306)
constant.310 = f32[] constant(inf)
broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={}
- compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED
+ compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ
not.313 = pred[4]{0} not(compare.312)
and.314 = pred[4]{0} and(not.308, not.313)
add.315 = f32[4]{0} add(add.304, add.269)
@@ -623,7 +623,7 @@
abs.399 = f32[4]{0} abs(subtract.396)
constant.400 = f32[] constant(inf)
broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={}
- compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED
+ compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ
not.403 = pred[4]{0} not(compare.402)
and.404 = pred[4]{0} and(not.398, not.403)
add.405 = f32[4]{0} add(add.394, add.359)
@@ -646,7 +646,7 @@
abs.489 = f32[4]{0} abs(subtract.486)
constant.490 = f32[] constant(inf)
broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={}
- compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED
+ compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ
not.493 = pred[4]{0} not(compare.492)
and.494 = pred[4]{0} and(not.488, not.493)
add.495 = f32[4]{0} add(add.484, add.449)
@@ -669,7 +669,7 @@
abs.579 = f32[4]{0} abs(subtract.576)
constant.580 = f32[] constant(inf)
broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={}
- compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED
+ compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ
not.583 = pred[4]{0} not(compare.582)
and.584 = pred[4]{0} and(not.578, not.583)
add.585 = f32[4]{0} add(add.574, add.539)
@@ -692,7 +692,7 @@
abs.669 = f32[4]{0} abs(subtract.666)
constant.670 = f32[] constant(inf)
broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={}
- compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED
+ compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ
not.673 = pred[4]{0} not(compare.672)
and.674 = pred[4]{0} and(not.668, not.673)
add.675 = f32[4]{0} add(add.664, add.629)
@@ -715,7 +715,7 @@
abs.759 = f32[4]{0} abs(subtract.756)
constant.760 = f32[] constant(inf)
broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={}
- compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED
+ compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ
not.763 = pred[4]{0} not(compare.762)
and.764 = pred[4]{0} and(not.758, not.763)
add.765 = f32[4]{0} add(add.754, add.719)
@@ -738,7 +738,7 @@
abs.849 = f32[4]{0} abs(subtract.846)
constant.850 = f32[] constant(inf)
broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={}
- compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED
+ compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ
not.853 = pred[4]{0} not(compare.852)
and.854 = pred[4]{0} and(not.848, not.853)
add.855 = f32[4]{0} add(add.844, add.809)