[MLIR/XLA] Fix two bugs:
* Correctly handle trivial dimensions in TypeToShape.
* Correctly generate default comparison type in the XLA builder.

PiperOrigin-RevId: 336957760
Change-Id: I5e23797584cf670a994c63e7f3885e356e6bf053
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc
index 9aca3ce..3822e10 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc
@@ -139,7 +139,8 @@
       for (const auto& e : llvm::enumerate(strides)) {
         strides_with_indices.push_back({e.value(), e.index()});
       }
-      std::sort(strides_with_indices.begin(), strides_with_indices.end());
+      std::stable_sort(strides_with_indices.begin(),
+                       strides_with_indices.end());
 
       llvm::SmallVector<int64, 4> minor_to_major;
       int64_t stride = 1;
@@ -148,7 +149,7 @@
 
         // Either the affine map is not perfectly strided, or the dimensions
         // recovered from strides don't match the actual dimensions in shapes.
-        if (stride != pr.first) return {};
+        if (stride != pr.first && m.getShape()[pr.second] != 1) return {};
 
         stride *= m.getShape()[pr.second];
       }
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
index a4a2bc4..9741774 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc
@@ -196,5 +196,22 @@
   EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
 }
 
+TEST(TypeToShapeTest, ConvertMemRefToShape2) {
+  Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, {2, 4, 3, 3},
+                                               {2, 3, 1, 0});
+  MLIRContext context;
+  mlir::Builder builder(&context);
+
+  StatusOr<mlir::Type> mlir_type =
+      ConvertShapeToType<MemRefType>(shape, builder);
+  ASSERT_TRUE(mlir_type.ok());
+  mlir::Type type = mlir_type.ConsumeValueOrDie();
+  Shape converted = TypeToShape(type);
+  EXPECT_TRUE(ShapeUtil::Equal(
+      converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64,
+                                                {2, 4, 3, 3}, {2, 3, 1, 0})));
+  EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index c05b2c8..409cf37 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -257,6 +257,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 2ac3200..41212e6 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -691,8 +691,10 @@
 
 StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                     ComparisonDirection direction) {
-  return Compare(shape, lhs, rhs, direction,
-                 Comparison::DefaultComparisonType(shape.element_type()));
+  TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
+  return Compare(
+      shape, lhs, rhs, direction,
+      Comparison::DefaultComparisonType(operand_shape.element_type()));
 }
 
 StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 7011c94..bfd13c8 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -19,6 +19,8 @@
 
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -1203,5 +1205,16 @@
   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
   ExpectInstructionsAttributesMatch(*module, expected);
 }
+
+TEST_F(XlaBuilderTest, ComparisonType) {
+  XlaBuilder b(TestName());
+  (void)Le(ConstantR0<int32>(&b, 1), ConstantR0<int32>(&b, 2));
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, op::Compare(op::Constant(), op::Constant()));
+  EXPECT_EQ(Comparison::Type::kSigned,
+            DynCast<HloCompareInstruction>(root)->type());
+}
+
 }  // namespace
 }  // namespace xla