Canonicalize MHLO Case and If Ops with constant conditions
ReplaceOpWithRegion was taken directly from the SCF ops; we should probably move it somewhere common in core.
PiperOrigin-RevId: 365936724
Change-Id: Ibf06126dbb44219265472abb820320d718437484
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 122a90c..efb041a 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -549,6 +549,8 @@
// TODO(b/129422361): ConditionalOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
+
+ let hasCanonicalizer = 1;
}
// Xla Client API has two separate calls for indexed and predicated conditional,
@@ -569,6 +571,8 @@
let results = (outs Variadic<HLO_TensorOrTuple>);
let hasCustomHLOConverter = 1;
+
+ let hasCanonicalizer = 1;
}
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index bdc296d..1ccfed8 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -45,6 +45,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
@@ -131,6 +132,19 @@
return GetI64ElementsAttr(slice_limits, builder);
}
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
+ Region& region, ValueRange blockArgs = {}) {
+ assert(llvm::hasSingleElement(region) && "expected single-block region");
+  Block* block = &region.front();
+ Operation* terminator = block->getTerminator();
+ ValueRange results = terminator->getOperands();
+ rewriter.mergeBlockBefore(block, op, blockArgs);
+ rewriter.replaceOp(op, results);
+ rewriter.eraseOp(terminator);
+}
+
#include "mhlo_canonicalize.inc"
} // namespace
@@ -2129,6 +2143,24 @@
return success();
}
+static LogicalResult InlineIfConstantCondition(IfOp ifOp,
+ PatternRewriter& rewriter) {
+ DenseIntElementsAttr pred_attr;
+ if (!matchPattern(ifOp.pred(), m_Constant(&pred_attr))) return failure();
+
+ if (pred_attr.getSplatValue<BoolAttr>().getValue()) {
+ ReplaceOpWithRegion(rewriter, ifOp, ifOp.true_branch(), ifOp.true_arg());
+ } else {
+ ReplaceOpWithRegion(rewriter, ifOp, ifOp.false_branch(), ifOp.false_arg());
+ }
+ return success();
+}
+
+void IfOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
+ MLIRContext* context) {
+ results.add(&InlineIfConstantCondition);
+}
+
//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//
@@ -2150,6 +2182,31 @@
return success();
}
+static LogicalResult InlineCaseConstantCondition(CaseOp caseOp,
+ PatternRewriter& rewriter) {
+ DenseIntElementsAttr index_attr;
+ if (!matchPattern(caseOp.index(), m_Constant(&index_attr))) {
+ return failure();
+ }
+ int64_t index =
+ index_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
+ // For an OOB index, the last branch is executed as the default branch:
+ // https://www.tensorflow.org/xla/operation_semantics#conditional
+ if (index < 0 || index >= caseOp.getNumRegions())
+ index = caseOp.getNumRegions() - 1;
+
+ Region& region = caseOp.getRegion(index);
+ if (!llvm::hasSingleElement(region)) return failure();
+ ReplaceOpWithRegion(rewriter, caseOp, region,
+ caseOp.branch_operands()[index]);
+ return success();
+}
+
+void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
+ MLIRContext* context) {
+ results.add(&InlineCaseConstantCondition);
+}
+
//===----------------------------------------------------------------------===//
// SqrtOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
index c234d32..70a86e1 100644
--- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
@@ -1280,6 +1280,108 @@
return %1 : tensor<4xf32>
}
+// CHECK-LABEL: func @fold_if_true(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: )
+func @fold_if_true(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NOT: mhlo.if
+ // CHECK: return %[[ARG0]]
+ %true = mhlo.constant dense<true> : tensor<i1>
+ %0 = "mhlo.if"(%true, %arg0, %arg1) ( {
+ ^bb0(%bbarg0: tensor<f32>):
+ "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%bbarg1: tensor<f32>):
+ "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @fold_if_false(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: )
+func @fold_if_false(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NOT: mhlo.if
+ // CHECK: return %[[ARG1]]
+ %false = mhlo.constant dense<false> : tensor<i1>
+ %0 = "mhlo.if"(%false, %arg0, %arg1) ( {
+ ^bb0(%bbarg0: tensor<f32>):
+ "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%bbarg1: tensor<f32>):
+ "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @fold_case(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: )
+func @fold_case(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NOT: mhlo.case
+ // CHECK: return %[[ARG1]]
+ %c1 = mhlo.constant dense<1> : tensor<i32>
+ %0 = "mhlo.case"(%c1, %arg0, %arg1, %arg2) ( {
+ ^bb0(%bbarg0: tensor<f32>):
+ "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%bbarg1: tensor<f32>):
+ "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%bbarg2: tensor<f32>):
+ "mhlo.return"(%bbarg2) : (tensor<f32>) -> ()
+ }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @fold_case_negative_index(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: )
+func @fold_case_negative_index(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NOT: mhlo.case
+ // CHECK: return %[[ARG2]]
+ %m1000 = mhlo.constant dense<-1000> : tensor<i32>
+ %0 = "mhlo.case"(%m1000, %arg0, %arg1, %arg2) ( {
+ ^bb0(%bbarg0: tensor<f32>):
+ "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%bbarg1: tensor<f32>):
+ "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%bbarg2: tensor<f32>):
+ "mhlo.return"(%bbarg2) : (tensor<f32>) -> ()
+ }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @fold_case_oob_index(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: )
+func @fold_case_oob_index(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NOT: mhlo.case
+ // CHECK: return %[[ARG2]]
+ %c1000 = mhlo.constant dense<1000> : tensor<i32>
+ %0 = "mhlo.case"(%c1000, %arg0, %arg1, %arg2) ( {
+ ^bb0(%bbarg0: tensor<f32>):
+ "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%bbarg1: tensor<f32>):
+ "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%bbarg2: tensor<f32>):
+ "mhlo.return"(%bbarg2) : (tensor<f32>) -> ()
+ }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
// CHECK-LABEL: @tensor_flow_scatter_v1_update
func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> {
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>