Concatenation zero-length operand removal
Remove any operand that is zero dimension. If any dimension is zero length, replace with a constant op.
PiperOrigin-RevId: 309123404
Change-Id: I7f1e56666693accfe3df8bdd35494047c69724ca
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index c846d27b..cea25f4 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -811,9 +811,53 @@
// ConcatenateOp
//===----------------------------------------------------------------------===//
+namespace {
+class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ConcatenateOp op,
+ PatternRewriter& rewriter) const override {
+ auto axis = op.dimension().getLimitedValue();
+ llvm::SmallVector<Value, 6> new_operands;
+ for (auto operand : op.getOperands()) {
+ auto ty = operand.getType().cast<ShapedType>();
+ if (ty.getDimSize(axis) != 0) {
+ new_operands.push_back(operand);
+ }
+ }
+
+ if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
+ rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
+ new_operands, op.dimension());
+ return success();
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void ConcatenateOp::getCanonicalizationPatterns(
+ OwningRewritePatternList& results, MLIRContext* context) {
+ results.insert<ConcatenateOperandRemoval>(context);
+}
+
OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
if (getNumOperands() == 1) return getOperand(0);
- return {};
+
+ ShapedType type = getResult().getType().cast<ShapedType>();
+ if (!type.hasStaticShape()) return {};
+
+ auto axis = dimension().getLimitedValue();
+ llvm::SmallVector<Value, 6> new_operands;
+ for (auto operand : getOperands()) {
+ auto ty = operand.getType().cast<ShapedType>();
+ if (ty.getDimSize(axis) != 0) {
+ return {};
+ }
+ }
+
+ return DenseElementsAttr::get(type, ArrayRef<Attribute>());
}
static LogicalResult Verify(ConcatenateOp op) {
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 27c53d6..16c9a7b 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -847,6 +847,7 @@
let results = (outs HLO_Tensor);
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index 48645ff..5f28693 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -1,5 +1,50 @@
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
+// CHECK-LABEL: concatenate_noop
+func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
+ %0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
+
+ // CHECK: return [[ARG]]
+ return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: concatenate_remove_operand
+func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
+ // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
+ // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
+
+ // CHECK: return [[ARG0]]
+ return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: concatenate_empty_bool
+func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
+ // CHECK: xla_hlo.constant
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
+
+ return %0 : tensor<0xi1>
+}
+
+// CHECK-LABEL: concatenate_empty_int
+func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
+ // CHECK: xla_hlo.constant
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
+
+ return %0 : tensor<0xi32>
+}
+
+// CHECK-LABEL: concatenate_empty_float
+func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
+ // CHECK: xla_hlo.constant
+ %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
+
+ return %0 : tensor<0xf32>
+}
+
+
+// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "xla_hlo.dynamic-slice"
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>