[MLIR/XLA] Fix two bugs:
* Correctly handle trivial dimensions in TypeToShape.
* Correctly generate default comparison type in the XLA builder.
PiperOrigin-RevId: 336957760
Change-Id: I5e23797584cf670a994c63e7f3885e356e6bf053
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc
index 9aca3ce..3822e10 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc
@@ -139,7 +139,8 @@
for (const auto& e : llvm::enumerate(strides)) {
strides_with_indices.push_back({e.value(), e.index()});
}
- std::sort(strides_with_indices.begin(), strides_with_indices.end());
+ std::stable_sort(strides_with_indices.begin(),
+ strides_with_indices.end());
llvm::SmallVector<int64, 4> minor_to_major;
int64_t stride = 1;
@@ -148,7 +149,7 @@
// Either the affine map is not perfectly strided, or the dimensions
// recovered from strides don't match the actual dimensions in shapes.
- if (stride != pr.first) return {};
+ if (stride != pr.first && m.getShape()[pr.second] != 1) return {};
stride *= m.getShape()[pr.second];
}
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
index a4a2bc4..9741774 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
@@ -196,5 +196,22 @@
EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
}
+TEST(TypeToShapeTest, ConvertMemRefToShape2) {
+ Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, {2, 4, 3, 3},
+ {2, 3, 1, 0});
+ MLIRContext context;
+ mlir::Builder builder(&context);
+
+ StatusOr<mlir::Type> mlir_type =
+ ConvertShapeToType<MemRefType>(shape, builder);
+ ASSERT_TRUE(mlir_type.ok());
+ mlir::Type type = mlir_type.ConsumeValueOrDie();
+ Shape converted = TypeToShape(type);
+ EXPECT_TRUE(ShapeUtil::Equal(
+ converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64,
+ {2, 4, 3, 3}, {2, 3, 1, 0})));
+ EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index c05b2c8..409cf37 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -257,6 +257,7 @@
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 2ac3200..41212e6 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -691,8 +691,10 @@
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
ComparisonDirection direction) {
- return Compare(shape, lhs, rhs, direction,
- Comparison::DefaultComparisonType(shape.element_type()));
+ TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
+ return Compare(
+ shape, lhs, rhs, direction,
+ Comparison::DefaultComparisonType(operand_shape.element_type()));
}
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 7011c94..bfd13c8 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -19,6 +19,8 @@
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -1203,5 +1205,16 @@
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
ExpectInstructionsAttributesMatch(*module, expected);
}
+
+TEST_F(XlaBuilderTest, ComparisonType) {
+ XlaBuilder b(TestName());
+ (void)Le(ConstantR0<int32>(&b, 1), ConstantR0<int32>(&b, 2));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, op::Compare(op::Constant(), op::Constant()));
+ EXPECT_EQ(Comparison::Type::kSigned,
+ DynCast<HloCompareInstruction>(root)->type());
+}
+
} // namespace
} // namespace xla