[XLA:ALGEBRAIC_SIMPLIFIER] cleanup the IsConvertPairNoOp to use the primitive util formulation.

PiperOrigin-RevId: 403189355
Change-Id: I4efa6af1de823b96601731459d816384b2a3ebe0
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 41fc9e2..7dcb375 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -289,35 +289,16 @@
   //    [operand_convert]         [convert]
   // (src)->convert-(intermediate)->convert-(dest)
   const HloInstruction* operand_convert = convert->operand(0);
-  CHECK_EQ(operand_convert->opcode(), HloOpcode::kConvert);
-  const Shape& src_shape = operand_convert->operand(0)->shape();
-  const Shape& intermediate_shape = operand_convert->shape();
-  const Shape& dest_shape = convert->shape();
-
-  const PrimitiveType src_type = src_shape.element_type();
-  const PrimitiveType intermediate_type = intermediate_shape.element_type();
-  const PrimitiveType dest_type = dest_shape.element_type();
-
-  // src_type must be equal to dest_type.
-  if (src_type != dest_type) {
+  if (operand_convert->opcode() != HloOpcode::kConvert) {
     return false;
   }
+  const PrimitiveType src_type =
+      operand_convert->operand(0)->shape().element_type();
+  const PrimitiveType intermediate_type =
+      operand_convert->shape().element_type();
 
-  // src_type must be a larger container than intermediate_type.
-  if (ShapeUtil::ByteSizeOfPrimitiveType(intermediate_type) <=
-      ShapeUtil::ByteSizeOfPrimitiveType(src_type)) {
-    return false;
-  }
-
-  // Both src_type and intermediate_type must be either floating or integral.
-  bool is_conversion_floating =
-      ShapeUtil::ElementIsFloating(src_shape) &&
-      ShapeUtil::ElementIsFloating(intermediate_shape);
-  bool is_conversion_integral =
-      ShapeUtil::ElementIsIntegral(src_shape) &&
-      ShapeUtil::ElementIsIntegral(intermediate_shape);
-
-  return is_conversion_floating || is_conversion_integral;
+  return src_type == convert->shape().element_type() &&
+         primitive_util::CastPreservesValues(src_type, intermediate_type);
 }
 
 PrecisionConfig SwapOperandsInDotPrecisionConfig(PrecisionConfig config) {
@@ -3226,8 +3207,7 @@
   //    convert(convert(A, $TYPE1), $TYPE2)) is simplified to Tuple(convert(A,
   //    $TYPE1) , floor(A), A) -> a case where the first convert has a
   //    fan-out
-  if (convert->operand(0)->opcode() == HloOpcode::kConvert &&
-      IsConvertPairNoOp(convert)) {
+  if (IsConvertPairNoOp(convert)) {
     return ReplaceInstruction(convert,
                               convert->mutable_operand(0)->mutable_operand(0));
   }