Updated concatenate op for unranked tensors
PiperOrigin-RevId: 316183263
Change-Id: I09511497e716aa69473bff416b5e458c6a3edb7b
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 3556dbf..38bff6c 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -836,21 +836,33 @@
auto dimension = dimension_attr.getInt();
auto first_type = (*operands.begin()).getType().cast<ShapedType>();
-
auto out_element = first_type.getElementType();
+
+ for (auto operand : operands.getTypes()) {
+ auto element_type = getElementTypeOrSelf(operand);
+ if (element_type != out_element) {
+ return failure();
+ }
+ }
+
+ // If an input is unranked the output shape is unranked.
+ if (!first_type.hasRank()) {
+ inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
+ return success();
+ }
+
auto out_shape = llvm::to_vector<6>(first_type.getShape());
out_shape[dimension] = 0;
for (auto operand : operands.getTypes()) {
auto type = operand.cast<ShapedType>();
- auto dim = type.getShape()[dimension];
-
- // Validate the element types match.
- if (type.getElementType() != out_element) {
- return failure();
+ if (!type.hasRank()) {
+ inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
+ return success();
}
// If the dimension is dynamic we know the output dimension is dynamic.
+ auto dim = type.getShape()[dimension];
if (dim == -1) {
out_shape[dimension] = -1;
break;
@@ -937,26 +949,39 @@
}
static LogicalResult Verify(ConcatenateOp op) {
- auto firstType = op.getOperand(0).getType().cast<RankedTensorType>();
+ Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
+ RankedTensorType first_ranked_type;
+ int num_operands = op.getNumOperands();
+ for (int i = 0; i < num_operands; i++) {
+ auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
+ if (second_type.getElementType() != element_type) {
+ return op.emitOpError(
+ llvm::formatv("operands (0) and ({0}) do not match element type", i));
+ }
- auto firstShape = firstType.getShape();
- int numOperands = op.getNumOperands();
- for (int i = 1; i < numOperands; i++) {
- auto secondType = op.getOperand(i).getType().cast<RankedTensorType>();
+ if (!second_type.hasRank()) {
+ continue;
+ }
- if (firstType.getRank() != secondType.getRank()) {
+ if (!first_ranked_type) {
+ first_ranked_type = second_type.cast<RankedTensorType>();
+ continue;
+ }
+
+ if (first_ranked_type.getRank() != second_type.getRank()) {
return op.emitOpError(
llvm::formatv("operands (0) and ({0}) do not match rank", i));
}
- auto secondShape = secondType.getShape();
- for (int d = 0; d < firstType.getRank(); ++d) {
- if (firstShape[d] != secondShape[d] && d != op.dimension()) {
+ auto first_shape = second_type.getShape();
+ auto second_shape = second_type.getShape();
+ for (int d = 0; d < first_ranked_type.getRank(); ++d) {
+ if (first_shape[d] != second_shape[d] && d != op.dimension()) {
return op.emitOpError(llvm::formatv(
"operands (0) and ({0}) non-concat dimensions do not match "
"({1}) != ({2})",
- i, llvm::make_range(firstShape.begin(), firstShape.end()),
- llvm::make_range(secondShape.begin(), secondShape.end())));
+ i, llvm::make_range(first_shape.begin(), first_shape.end()),
+ llvm::make_range(second_shape.begin(), second_shape.end())));
}
}
}
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 38964fb..2c68a0f 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -304,6 +304,46 @@
// -----
+// CHECK-LABEL: @concat_1D
+func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
+ return %0 : tensor<3xi32>
+}
+
+// -----
+
+func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
+ // expected-error@+1 {{'xla_hlo.concatenate' op requires the same element type for all operands and results}}
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
+ return %0 : tensor<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concat_1D_unranked
+func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
+ return %0 : tensor<*xi32>
+}
+
+// -----
+
+func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
+ // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
+ return %0 : tensor<3xi32>
+}
+
+// -----
+
+func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
+ // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @clamp
func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>