Move Iota folding to xla_hlo-to-std legalization.
Iota op folding can produce huge constants, and not every target wants to pay the resulting increase in file size. Move the fold to the xla_hlo-to-standard legalization pass, which was the original intent behind adding it.
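
For illustration (taken from the tests below), a small iota such as

  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>

is no longer canonicalized away; it is materialized to

  %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>

only when a target explicitly runs the -xla-legalize-to-std pass. For large shapes the dense constant can dominate the emitted file, which is what this change lets targets avoid.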
PiperOrigin-RevId: 296282985
Change-Id: I71cb1679796ff0d36251ddb2f4cd0fce8aa75192
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 481c12b..41ef869 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -177,31 +177,6 @@
// IotaOp
//===----------------------------------------------------------------------===//
-OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
- const auto output_type = getResult().getType().cast<ShapedType>();
- const auto output_size = output_type.getNumElements();
- const auto dimension = iota_dimension().getSExtValue();
- const auto max_dim_size = output_type.getDimSize(dimension);
- int bitwidth = output_type.getElementType().getIntOrFloatBitWidth();
-
- llvm::SmallVector<APInt, 10> values;
- values.reserve(output_size);
-
- int64_t increase_stride = output_size;
- for (int i = 0; i <= dimension; i++) {
- increase_stride /= output_type.getDimSize(i);
- }
-
- int64_t current_value = 0;
- for (int i = 0; i < output_size; i++) {
- int64_t value = (current_value / increase_stride) % max_dim_size;
- values.push_back(APInt(bitwidth, value));
- ++current_value;
- }
-
- return DenseIntElementsAttr::get(output_type, values);
-}
-
static LogicalResult Verify(IotaOp op) {
auto shape = op.getType().cast<ShapedType>();
if (!shape.hasRank()) return success();
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 28c0a85..269e1cc 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -122,8 +122,6 @@
let results = (outs HLO_IntFpOrComplexTensor:$output);
- let hasFolder = 1;
-
// TODO(b/130357376): Iota has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index fa39b77..2232063 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -49,6 +49,14 @@
return %2 : tensor<4xcomplex<f32>>
}
+// CHECK-LABEL: @iota_not_lowered_to_constant
+func @iota_not_lowered_to_constant() -> tensor<4xi32> {
+ // CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
+ // CHECK: return [[RESULT]]
+ %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+}
+
// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
diff --git a/tensorflow/compiler/mlir/xla/tests/iota.mlir b/tensorflow/compiler/mlir/xla/tests/iota.mlir
deleted file mode 100644
index 65b9f73..0000000
--- a/tensorflow/compiler/mlir/xla/tests/iota.mlir
+++ /dev/null
@@ -1,61 +0,0 @@
-// RUN: tf-opt %s -split-input-file -xla-legalize-to-std | FileCheck %s
-
-// -----
-
-// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
-func @iota.const.1() -> tensor<4xi32> {
- // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
- %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
- // CHECK-NEXT: return %[[CST]] : tensor<4xi32>
- return %0 : tensor<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
-func @iota.const.2() -> tensor<2x4xi32> {
- // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
- %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
- // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
- return %0 : tensor<2x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
-func @iota.const.3() -> tensor<2x4xi32> {
- // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
- %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
- // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
- return %0 : tensor<2x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
-func @iota.const.4() -> tensor<2x3x4xi32> {
- // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
- %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
- // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
- return %0 : tensor<2x3x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
-func @iota.const.5() -> tensor<2x3x4xi32> {
- // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
- %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
- // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
- return %0 : tensor<2x3x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
-func @iota.const.6() -> tensor<2x3x4xi32> {
- // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
- %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
- // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
- return %0 : tensor<2x3x4xi32>
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
index 1d2cf76..f56174a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
@@ -135,3 +135,51 @@
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
}
+// Test iota lowering to constants.
+// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
+func @iota.const.1() -> tensor<4xi32> {
+ // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
+ %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
+ // CHECK-NEXT: return %[[CST]] : tensor<4xi32>
+ return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
+func @iota.const.2() -> tensor<2x4xi32> {
+ // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
+ %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
+ // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
+ return %0 : tensor<2x4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
+func @iota.const.3() -> tensor<2x4xi32> {
+ // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
+ %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
+ // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
+ return %0 : tensor<2x4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
+func @iota.const.4() -> tensor<2x3x4xi32> {
+ // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
+ %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
+ // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
+ return %0 : tensor<2x3x4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
+func @iota.const.5() -> tensor<2x3x4xi32> {
+ // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
+ %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
+ // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
+ return %0 : tensor<2x3x4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
+func @iota.const.6() -> tensor<2x3x4xi32> {
+ // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
+ %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
+ // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
+ return %0 : tensor<2x3x4xi32>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
index 9720d2a..5ee6010 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
@@ -105,6 +105,41 @@
}
};
+class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op,
+ PatternRewriter &rewriter) const override {
+ auto output_type = op.getType().cast<ShapedType>();
+ // TODO(prakalps): Handle FP and ComplexType iota ops.
+ if (!output_type.getElementType().isa<IntegerType>()) return matchFailure();
+ auto output_size = output_type.getNumElements();
+ auto dimension = op.iota_dimension().getSExtValue();
+ auto max_dim_size = output_type.getDimSize(dimension);
+ int bitwidth = output_type.getElementType().getIntOrFloatBitWidth();
+
+ llvm::SmallVector<APInt, 10> values;
+ values.reserve(output_size);
+
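+    // In row-major order, the iota value along `dimension` increments once
+    // every `increase_stride` elements, i.e. once per product of all
+    // dimension sizes trailing `dimension`.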
+ int64_t increase_stride = output_size;
+ for (int i = 0; i <= dimension; i++) {
+ increase_stride /= output_type.getDimSize(i);
+ }
+
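+    // Map each linear index to its coordinate along `dimension`: divide out
+    // the stride, then wrap at the dimension size.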
+ int64_t current_value = 0;
+ for (int i = 0; i < output_size; i++) {
+ int64_t value = (current_value / increase_stride) % max_dim_size;
+ values.push_back(APInt(bitwidth, value));
+ ++current_value;
+ }
+
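+    // Replace the iota with a standard-dialect constant holding the
+    // materialized values.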
+ rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
+ op, DenseIntElementsAttr::get(output_type, values));
+ return matchSuccess();
+ }
+};
+
} // end anonymous namespace
namespace {
@@ -121,9 +156,7 @@
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
mlir::MLIRContext *ctx) {
mlir::populateWithGenerated(ctx, patterns);
- patterns
- ->insert<mlir::xla_hlo::CompareFConvert, mlir::xla_hlo::CompareIConvert>(
- ctx);
+ patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
}
/// Perform the lowering to standard dialect.