[XLA:ALGEBRAIC_SIMPLIFIER] cleanup the IsConvertPairNoOp to use the primitive util formulation.
PiperOrigin-RevId: 403189355
Change-Id: I4efa6af1de823b96601731459d816384b2a3ebe0
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 41fc9e2..7dcb375 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -289,35 +289,16 @@
// [operand_convert] [convert]
// (src)->convert-(intermediate)->convert-(dest)
const HloInstruction* operand_convert = convert->operand(0);
- CHECK_EQ(operand_convert->opcode(), HloOpcode::kConvert);
- const Shape& src_shape = operand_convert->operand(0)->shape();
- const Shape& intermediate_shape = operand_convert->shape();
- const Shape& dest_shape = convert->shape();
-
- const PrimitiveType src_type = src_shape.element_type();
- const PrimitiveType intermediate_type = intermediate_shape.element_type();
- const PrimitiveType dest_type = dest_shape.element_type();
-
- // src_type must be equal to dest_type.
- if (src_type != dest_type) {
+ if (operand_convert->opcode() != HloOpcode::kConvert) {
return false;
}
+ const PrimitiveType src_type =
+ operand_convert->operand(0)->shape().element_type();
+ const PrimitiveType intermediate_type =
+ operand_convert->shape().element_type();
- // src_type must be a larger container than intermediate_type.
- if (ShapeUtil::ByteSizeOfPrimitiveType(intermediate_type) <=
- ShapeUtil::ByteSizeOfPrimitiveType(src_type)) {
- return false;
- }
-
- // Both src_type and intermediate_type must be either floating or integral.
- bool is_conversion_floating =
- ShapeUtil::ElementIsFloating(src_shape) &&
- ShapeUtil::ElementIsFloating(intermediate_shape);
- bool is_conversion_integral =
- ShapeUtil::ElementIsIntegral(src_shape) &&
- ShapeUtil::ElementIsIntegral(intermediate_shape);
-
- return is_conversion_floating || is_conversion_integral;
+ return src_type == convert->shape().element_type() &&
+ primitive_util::CastPreservesValues(src_type, intermediate_type);
}
PrecisionConfig SwapOperandsInDotPrecisionConfig(PrecisionConfig config) {
@@ -3226,8 +3207,7 @@
// convert(convert(A, $TYPE1), $TYPE2)) is simplified to Tuple(convert(A,
// $TYPE1) , floor(A), A) -> a case where the first convert has a
// fan-out
- if (convert->operand(0)->opcode() == HloOpcode::kConvert &&
- IsConvertPairNoOp(convert)) {
+ if (IsConvertPairNoOp(convert)) {
return ReplaceInstruction(convert,
convert->mutable_operand(0)->mutable_operand(0));
}