Add canonicalization pattern for HLO_ScalarsToDimensionTensor.
PiperOrigin-RevId: 297848928
Change-Id: I4b2413c6105bf34b56d12aba24b65f0a5f972211
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index b011b60..f44bb9d 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -32,11 +32,13 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
+#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/OpDefinition.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
@@ -46,6 +48,7 @@
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
+#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
@@ -480,6 +483,48 @@
}
//===----------------------------------------------------------------------===//
+// ScalarsToDimensionTensorOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Canonicalizes the pattern of the form
+//
+// %2 = "xla_hlo.scalars_to_dimension_tensor"(%0, %1)
+// : (i32, i32) -> tensor<2xi32>
+// %3 = extract_element %2[%c0] : tensor<2xi32>
+//
+// to just %0.
+struct ExtractElementFromScalarsToDimensionTensor
+ : public OpRewritePattern<ExtractElementOp> {
+ using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(ExtractElementOp extract,
+ PatternRewriter& rewriter) const override {
+ if (extract.indices().size() != 1) return matchFailure();
+
+ if (auto scalars_to_tensor = dyn_cast_or_null<ScalarsToDimensionTensorOp>(
+ extract.aggregate().getDefiningOp())) {
+ APInt index;
+ if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) {
+ return matchFailure();
+ }
+ rewriter.replaceOp(extract,
+ scalars_to_tensor.getOperand(index.getZExtValue()));
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+};
+
+} // namespace
+
+void ScalarsToDimensionTensorOp::getCanonicalizationPatterns(
+ OwningRewritePatternList& results, MLIRContext* context) {
+ results.insert<ExtractElementFromScalarsToDimensionTensor>(context);
+}
+
+//===----------------------------------------------------------------------===//
// DynamicBroadcastInDimOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 42b42d9..8fe7bb9 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -784,11 +784,12 @@
compute shape arguments to dynamic operations.
}];
- let arguments = (ins Variadic<AnySignlessInteger>);
+ let arguments = (ins Variadic<AnySignlessInteger>:$scalars);
let results = (outs HLO_DimensionTensor);
// Cannot be exported to legacy formats.
let hasCustomHLOConverter = 1;
+ let hasCanonicalizer = 1;
}
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index 2232063..b73cfcf 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -64,3 +64,13 @@
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
+
+// CHECK-LABEL: @extract_scalars_to_tensor
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @extract_scalars_to_tensor(%arg0: i32, %arg1: i32) -> i32 {
+ %0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32>
+ %1 = constant 0 : index
+ %2 = extract_element %0[%1] : tensor<2xi32>
+ // CHECK: return %[[ARG0]]
+ return %2 : i32
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 7e2845d..ce70c78 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -453,6 +453,14 @@
// -----
+// CHECK-LABEL: @scalars_to_dimension_tensor
+func @scalars_to_dimension_tensor(%arg0: i32, %arg1: i32) -> tensor<2xi32> {
+ %0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32>
+ return %0 : tensor<2xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @select
func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>