Concatenation zero-length operand removal

Remove any concatenate operand whose size along the concatenation dimension is zero, since it contributes no elements to the result. If every operand is zero-length along that dimension, fold the whole concatenation to an empty constant op.

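For example, as exercised by the new canonicalize.mlir tests below, a
concatenate with a zero-length operand

  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 }
      : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>

canonicalizes to a direct use of %arg0, and a concatenate whose operands
are all zero-length along the dimension folds to an xla_hlo.constant.
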
PiperOrigin-RevId: 309123404
Change-Id: I7f1e56666693accfe3df8bdd35494047c69724ca
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index c846d27b..cea25f4 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -811,9 +811,62 @@
 // ConcatenateOp
 //===----------------------------------------------------------------------===//
 
+namespace {
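+// Removes concatenate operands that are zero-sized along the concatenation
+// dimension; they contribute no elements to the result.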
+class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConcatenateOp op,
+                                PatternRewriter& rewriter) const override {
+    auto axis = op.dimension().getLimitedValue();
+    llvm::SmallVector<Value, 6> new_operands;
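+    // Keep only operands with a non-zero size along the concatenation
+    // dimension.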
+    for (auto operand : op.getOperands()) {
+      auto ty = operand.getType().cast<ShapedType>();
+      if (ty.getDimSize(axis) != 0) {
+        new_operands.push_back(operand);
+      }
+    }
+
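+    // Rebuild the concatenate only if at least one operand survives and at
+    // least one was dropped; otherwise leave the op for fold() to handle.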
+    if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
+                                                 new_operands, op.dimension());
+      return success();
+    }
+
+    return failure();
+  }
+};
+}  // namespace
+
+void ConcatenateOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<ConcatenateOperandRemoval>(context);
+}
+
 OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
   if (getNumOperands() == 1) return getOperand(0);
-  return {};
+
+  ShapedType type = getResult().getType().cast<ShapedType>();
+  if (!type.hasStaticShape()) return {};
+
+  auto axis = dimension().getLimitedValue();
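+  // Fold to an empty constant only when every operand is zero-sized along
+  // the concatenation dimension.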
+  for (auto operand : getOperands()) {
+    auto ty = operand.getType().cast<ShapedType>();
+    if (ty.getDimSize(axis) != 0) {
+      return {};
+    }
+  }
+
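+  // Every operand is empty along the axis, so the statically shaped result
+  // has no elements.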
+  return DenseElementsAttr::get(type, ArrayRef<Attribute>());
 }
 
 static LogicalResult Verify(ConcatenateOp op) {
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 27c53d6..16c9a7b 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -847,6 +847,7 @@
 
   let results = (outs HLO_Tensor);
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index 48645ff..5f28693 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -1,5 +1,49 @@
 // RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
 
+// CHECK-LABEL: concatenate_noop
+func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
+  %0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
+
+  // CHECK: return [[ARG]]
+  return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: concatenate_remove_operand
+func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
+  // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
+  // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
+
+  // CHECK: return [[ARG0]]
+  return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: concatenate_empty_bool
+func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
+  // CHECK: xla_hlo.constant
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
+
+  return %0 : tensor<0xi1>
+}
+
+// CHECK-LABEL: concatenate_empty_int
+func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
+  // CHECK: xla_hlo.constant
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
+
+  return %0 : tensor<0xi32>
+}
+
+// CHECK-LABEL: concatenate_empty_float
+func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
+  // CHECK: xla_hlo.constant
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
+
+  return %0 : tensor<0xf32>
+}
+
+// CHECK-LABEL: dynamic_slice_variable_start
 func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
   // CHECK: "xla_hlo.dynamic-slice"
   %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>