Update concatenate op to support unranked tensors

PiperOrigin-RevId: 316183263
Change-Id: I09511497e716aa69473bff416b5e458c6a3edb7b
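
Sketch of the new behavior (informal; it mirrors the added tests, and the
operand names below are illustrative only): shape inference now produces an
unranked result as soon as any operand is unranked, and the verifier checks
rank and non-concat dimensions only between ranked operands, while still
requiring a common element type for all operands.

  // Mixing ranked and unranked operands now verifies, inferring an unranked result.
  %0 = "xla_hlo.concatenate"(%ranked, %unranked) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
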
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 3556dbf..38bff6c 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -836,21 +836,33 @@
   auto dimension = dimension_attr.getInt();
 
   auto first_type = (*operands.begin()).getType().cast<ShapedType>();
-
   auto out_element = first_type.getElementType();
+
+  for (auto operand : operands.getTypes()) {
+    auto element_type = getElementTypeOrSelf(operand);
+    if (element_type != out_element) {
+      return failure();
+    }
+  }
+
+  // If an input is unranked, the output shape is unranked.
+  if (!first_type.hasRank()) {
+    inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
+    return success();
+  }
+
   auto out_shape = llvm::to_vector<6>(first_type.getShape());
   out_shape[dimension] = 0;
 
   for (auto operand : operands.getTypes()) {
     auto type = operand.cast<ShapedType>();
-    auto dim = type.getShape()[dimension];
-
-    // Validate the element types match.
-    if (type.getElementType() != out_element) {
-      return failure();
+    if (!type.hasRank()) {
+      inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
+      return success();
     }
 
     // If the dimension is dynamic we know the output dimension is dynamic.
+    auto dim = type.getShape()[dimension];
     if (dim == -1) {
       out_shape[dimension] = -1;
       break;
@@ -937,26 +949,40 @@
 }
 
 static LogicalResult Verify(ConcatenateOp op) {
-  auto firstType = op.getOperand(0).getType().cast<RankedTensorType>();
+  Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
+  RankedTensorType first_ranked_type;
+  int num_operands = op.getNumOperands();
+  for (int i = 0; i < num_operands; i++) {
+    auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
+    if (second_type.getElementType() != element_type) {
+      return op.emitOpError(
+          llvm::formatv("operands (0) and ({0}) do not match element type", i));
+    }
 
-  auto firstShape = firstType.getShape();
-  int numOperands = op.getNumOperands();
-  for (int i = 1; i < numOperands; i++) {
-    auto secondType = op.getOperand(i).getType().cast<RankedTensorType>();
+    if (!second_type.hasRank()) {
+      continue;
+    }
 
-    if (firstType.getRank() != secondType.getRank()) {
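+    // Track the first ranked operand; later ranked operands are checked against it.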
+    if (!first_ranked_type) {
+      first_ranked_type = second_type.cast<RankedTensorType>();
+      continue;
+    }
+
+    if (first_ranked_type.getRank() != second_type.getRank()) {
       return op.emitOpError(
           llvm::formatv("operands (0) and ({0}) do not match rank", i));
     }
 
-    auto secondShape = secondType.getShape();
-    for (int d = 0; d < firstType.getRank(); ++d) {
-      if (firstShape[d] != secondShape[d] && d != op.dimension()) {
+    auto first_shape = first_ranked_type.getShape();
+    auto second_shape = second_type.getShape();
+    for (int d = 0; d < first_ranked_type.getRank(); ++d) {
+      if (first_shape[d] != second_shape[d] && d != op.dimension()) {
         return op.emitOpError(llvm::formatv(
             "operands (0) and ({0}) non-concat dimensions do not match "
             "({1}) != ({2})",
-            i, llvm::make_range(firstShape.begin(), firstShape.end()),
-            llvm::make_range(secondShape.begin(), secondShape.end())));
+            i, llvm::make_range(first_shape.begin(), first_shape.end()),
+            llvm::make_range(second_shape.begin(), second_shape.end())));
       }
     }
   }
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 38964fb..2c68a0f 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -304,6 +304,46 @@
 
 // -----
 
+// CHECK-LABEL: @concat_1D
+func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>)  -> tensor<3xi32> {
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
+  return %0 : tensor<3xi32>
+}
+
+// -----
+
+func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>)  -> tensor<3xi32> {
+  // expected-error@+1 {{'xla_hlo.concatenate' op requires the same element type for all operands and results}}
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
+  return %0 : tensor<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concat_1D_unranked
+func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>)  -> tensor<*xi32> {
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
+// -----
+
+func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>)  -> tensor<3xi32> {
+  // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
+  return %0 : tensor<3xi32>
+}
+
+// -----
+
+func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>)  -> tensor<4xi32> {
+  // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @clamp
 func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> {
   %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>