Add a method to describe numerical classifications for comparisons.
PiperOrigin-RevId: 441513245
diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h
index 5688c91..2189121 100644
--- a/tensorflow/compiler/xla/comparison_util.h
+++ b/tensorflow/compiler/xla/comparison_util.h
@@ -134,6 +134,13 @@
return primitive_type_ == PrimitiveType::U32 && IsTotalOrder();
}
+ inline bool IsIntegralPrimitiveType() const {
+ return primitive_util::IsIntegralType(primitive_type_);
+ }
+ inline bool IsFloatingPointPrimitiveType() const {
+ return primitive_util::IsFloatingPointType(primitive_type_);
+ }
+
// Returns whether (a dir a) is always true for this comparison.
bool IsReflexive() const;