[XLA] Verify the comparison type for comparisons

PiperOrigin-RevId: 343577367
Change-Id: Ic39711697188f4927d0b23eb7f4b6cda6a14630a
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 28f7682..171afa4 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -234,6 +234,7 @@
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xla/service:shape_inference",
         "//tensorflow/core:lib",
+        "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 9d2d190..3350653 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -34,6 +34,7 @@
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -46,6 +47,7 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
 
@@ -3350,6 +3352,11 @@
     }
 
     if (!need_rewrite) {
+      if (opcode == HloOpcode::kCompare) {
+        CHECK(!instr_proto->comparison_type().empty());
+        new_instr->set_comparison_type(
+            ComparisonTypeToString(Comparison::DefaultComparisonType(PRED)));
+      }
       *new_instr->mutable_name() =
           GetFullName(instr_proto->opcode(), kNameSeparator, id);
       return Status::OK();
@@ -4009,11 +4016,26 @@
   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
 }
 
+static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
+                               absl::Span<const int64> broadcast_dimensions,
+                               ComparisonDirection comparison_direction) {
+  auto b = lhs.builder();
+  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
+    auto operand_element_type = operand_shape.element_type();
+    auto compare_type =
+        primitive_util::IsFloatingPointType(operand_element_type)
+            ? Comparison::Type::kFloatTotalOrder
+            : Comparison::DefaultComparisonType(operand_element_type);
+    return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
+                   compare_type);
+  });
+}
+
 XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kEq);
 }
 
 XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
@@ -4023,9 +4045,8 @@
 
 XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kNe);
 }
 
 XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
@@ -4035,9 +4056,8 @@
 
 XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kGe);
 }
 
 XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
@@ -4047,9 +4067,8 @@
 
 XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kGt);
 }
 
 XlaOp Le(const XlaOp lhs, const XlaOp rhs,
@@ -4059,10 +4078,10 @@
 
 XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kLe);
 }
+
 XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
@@ -4070,8 +4089,8 @@
 
 XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
-                 Comparison::Type::kFloatTotalOrder);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kLt);
 }
 
 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index b1034bf..5fe89e8 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3721,6 +3721,7 @@
         ":hlo_casting_utils",
         ":hlo_pass",
         ":shape_inference",
+        "//tensorflow/compiler/xla:comparison_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index f2ea03f..e43f68f 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -25,6 +25,7 @@
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 8c66d00..84e4fe6 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -19,6 +19,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -1746,6 +1747,31 @@
           ShapeUtil::HumanString(operand_shape));
     }
   }
+  if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
+    const Shape& operand_shape = comparison->operand(1)->shape();
+    PrimitiveType operand_element_type = operand_shape.element_type();
+    Comparison::Type default_comparison_type =
+        Comparison::DefaultComparisonType(operand_element_type);
+    if (primitive_util::IsFloatingPointType(operand_element_type)) {
+      if (comparison->type() != Comparison::Type::kFloat &&
+          comparison->type() != Comparison::Type::kFloatTotalOrder) {
+        return FailedPrecondition(
+            "Expected comparison type %s or %s.\n"
+            "actual: %s\noperand: %s\n",
+            ComparisonTypeToString(Comparison::Type::kFloat),
+            ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
+            ComparisonTypeToString(comparison->type()),
+            ShapeUtil::HumanString(operand_shape));
+      }
+    } else if (comparison->type() != default_comparison_type) {
+      return FailedPrecondition(
+          "Expected comparison type %s.\n"
+          "actual: %s\noperand: %s\n",
+          ComparisonTypeToString(default_comparison_type),
+          ComparisonTypeToString(comparison->type()),
+          ShapeUtil::HumanString(operand_shape));
+    }
+  }
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 0df3016..c6c09e3 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -1220,5 +1220,77 @@
                         "needs to be collective-permute-start, found tuple"));
 }
 
+TEST_F(HloVerifierTest, ComparisonTypeFloat) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngOperandElementTypesNotMatch {
+   p0 = f32[] parameter(0)
+   ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypeSigned) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngOperandElementTypesNotMatch {
+   p0 = s32[] parameter(0)
+   ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected comparison type SIGNED"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngOperandElementTypesNotMatch {
+   p0 = u32[] parameter(0)
+   ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected comparison type UNSIGNED"));
+}
+
+TEST_F(HloVerifierTest, ComparisonTypePred) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngOperandElementTypesNotMatch {
+   p0 = pred[] parameter(0)
+   ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected comparison type UNSIGNED"));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index fe27a8c..69916f6 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -554,7 +554,7 @@
   abs.129 = f32[4]{0} abs(subtract.126)
   constant.130 = f32[] constant(inf)
   broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={}
-  compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED
+  compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ
   not.133 = pred[4]{0} not(compare.132)
   and.134 = pred[4]{0} and(not.128, not.133)
   add.135 = f32[4]{0} add(add.124, add.89)
@@ -577,7 +577,7 @@
   abs.219 = f32[4]{0} abs(subtract.216)
   constant.220 = f32[] constant(inf)
   broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={}
-  compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED
+  compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ
   not.223 = pred[4]{0} not(compare.222)
   and.224 = pred[4]{0} and(not.218, not.223)
   add.225 = f32[4]{0} add(add.214, add.179)
@@ -600,7 +600,7 @@
   abs.309 = f32[4]{0} abs(subtract.306)
   constant.310 = f32[] constant(inf)
   broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={}
-  compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED
+  compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ
   not.313 = pred[4]{0} not(compare.312)
   and.314 = pred[4]{0} and(not.308, not.313)
   add.315 = f32[4]{0} add(add.304, add.269)
@@ -623,7 +623,7 @@
   abs.399 = f32[4]{0} abs(subtract.396)
   constant.400 = f32[] constant(inf)
   broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={}
-  compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED
+  compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ
   not.403 = pred[4]{0} not(compare.402)
   and.404 = pred[4]{0} and(not.398, not.403)
   add.405 = f32[4]{0} add(add.394, add.359)
@@ -646,7 +646,7 @@
   abs.489 = f32[4]{0} abs(subtract.486)
   constant.490 = f32[] constant(inf)
   broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={}
-  compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED
+  compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ
   not.493 = pred[4]{0} not(compare.492)
   and.494 = pred[4]{0} and(not.488, not.493)
   add.495 = f32[4]{0} add(add.484, add.449)
@@ -669,7 +669,7 @@
   abs.579 = f32[4]{0} abs(subtract.576)
   constant.580 = f32[] constant(inf)
   broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={}
-  compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED
+  compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ
   not.583 = pred[4]{0} not(compare.582)
   and.584 = pred[4]{0} and(not.578, not.583)
   add.585 = f32[4]{0} add(add.574, add.539)
@@ -692,7 +692,7 @@
   abs.669 = f32[4]{0} abs(subtract.666)
   constant.670 = f32[] constant(inf)
   broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={}
-  compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED
+  compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ
   not.673 = pred[4]{0} not(compare.672)
   and.674 = pred[4]{0} and(not.668, not.673)
   add.675 = f32[4]{0} add(add.664, add.629)
@@ -715,7 +715,7 @@
   abs.759 = f32[4]{0} abs(subtract.756)
   constant.760 = f32[] constant(inf)
   broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={}
-  compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED
+  compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ
   not.763 = pred[4]{0} not(compare.762)
   and.764 = pred[4]{0} and(not.758, not.763)
   add.765 = f32[4]{0} add(add.754, add.719)
@@ -738,7 +738,7 @@
   abs.849 = f32[4]{0} abs(subtract.846)
   constant.850 = f32[] constant(inf)
   broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={}
-  compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED
+  compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ
   not.853 = pred[4]{0} not(compare.852)
   and.854 = pred[4]{0} and(not.848, not.853)
   add.855 = f32[4]{0} add(add.844, add.809)