Move Iota folding to xla_hlo-std legalization.

Iota op folding can lead to huge constants. Not every target would like to fold iota ops which increases the file size. Moving it to xla_hlo to standard legalization which was the original intent behind adding this fold.

PiperOrigin-RevId: 296282985
Change-Id: I71cb1679796ff0d36251ddb2f4cd0fce8aa75192
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 481c12b..41ef869 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -177,31 +177,6 @@
 // IotaOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
-  const auto output_type = getResult().getType().cast<ShapedType>();
-  const auto output_size = output_type.getNumElements();
-  const auto dimension = iota_dimension().getSExtValue();
-  const auto max_dim_size = output_type.getDimSize(dimension);
-  int bitwidth = output_type.getElementType().getIntOrFloatBitWidth();
-
-  llvm::SmallVector<APInt, 10> values;
-  values.reserve(output_size);
-
-  int64_t increase_stride = output_size;
-  for (int i = 0; i <= dimension; i++) {
-    increase_stride /= output_type.getDimSize(i);
-  }
-
-  int64_t current_value = 0;
-  for (int i = 0; i < output_size; i++) {
-    int64_t value = (current_value / increase_stride) % max_dim_size;
-    values.push_back(APInt(bitwidth, value));
-    ++current_value;
-  }
-
-  return DenseIntElementsAttr::get(output_type, values);
-}
-
 static LogicalResult Verify(IotaOp op) {
   auto shape = op.getType().cast<ShapedType>();
   if (!shape.hasRank()) return success();
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 28c0a85..269e1cc 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -122,8 +122,6 @@
 
   let results = (outs HLO_IntFpOrComplexTensor:$output);
 
-  let hasFolder = 1;
-
   // TODO(b/130357376): Iota has special conversion logic to HLO.
   let hasCustomHLOConverter = 1;
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index fa39b77..2232063 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -49,6 +49,14 @@
   return %2 : tensor<4xcomplex<f32>>
 }
 
+// CHECK-LABEL: @iota_not_lowered_to_constant
+func @iota_not_lowered_to_constant() -> tensor<4xi32> {
+  // CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
+  // CHECK: return [[RESULT]]
+  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
 // CHECK-LABEL: @unary_einsum
 func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
   // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
diff --git a/tensorflow/compiler/mlir/xla/tests/iota.mlir b/tensorflow/compiler/mlir/xla/tests/iota.mlir
deleted file mode 100644
index 65b9f73..0000000
--- a/tensorflow/compiler/mlir/xla/tests/iota.mlir
+++ /dev/null
@@ -1,61 +0,0 @@
-// RUN: tf-opt %s -split-input-file -xla-legalize-to-std | FileCheck %s
-
-// -----
-
-// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
-func @iota.const.1() -> tensor<4xi32> {
-  // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
-  // CHECK-NEXT: return %[[CST]] : tensor<4xi32>
-  return %0 : tensor<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
-func @iota.const.2() -> tensor<2x4xi32> {
-  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
-  // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
-  return %0 : tensor<2x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
-func @iota.const.3() -> tensor<2x4xi32> {
-  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
-  // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
-  return %0 : tensor<2x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
-func @iota.const.4() -> tensor<2x3x4xi32> {
-  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
-  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
-  return %0 : tensor<2x3x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
-func @iota.const.5() -> tensor<2x3x4xi32> {
-  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
-  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
-  return %0 : tensor<2x3x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
-func @iota.const.6() -> tensor<2x3x4xi32> {
-  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
-  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
-  return %0 : tensor<2x3x4xi32>
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
index 1d2cf76..f56174a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
@@ -135,3 +135,51 @@
   return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
 }
 
+// Test Iota lowering to constant
+// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
+func @iota.const.1() -> tensor<4xi32> {
+  // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
+  // CHECK-NEXT: return %[[CST]] : tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
+func @iota.const.2() -> tensor<2x4xi32> {
+  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
+  // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
+  return %0 : tensor<2x4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
+func @iota.const.3() -> tensor<2x4xi32> {
+  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
+  // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
+  return %0 : tensor<2x4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
+func @iota.const.4() -> tensor<2x3x4xi32> {
+  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
+  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
+  return %0 : tensor<2x3x4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
+func @iota.const.5() -> tensor<2x3x4xi32> {
+  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
+  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
+  return %0 : tensor<2x3x4xi32>
+}
+
+// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
+func @iota.const.6() -> tensor<2x3x4xi32> {
+  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
+  %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
+  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
+  return %0 : tensor<2x3x4xi32>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
index 9720d2a..5ee6010 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc
@@ -105,6 +105,41 @@
   }
 };
 
+class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto output_type = op.getType().cast<ShapedType>();
+    // TODO(prakalps): Handle FP and ComplexType iota ops.
+    if (!output_type.getElementType().isa<IntegerType>()) return matchFailure();
+    auto output_size = output_type.getNumElements();
+    auto dimension = op.iota_dimension().getSExtValue();
+    auto max_dim_size = output_type.getDimSize(dimension);
+    int bitwidth = output_type.getElementType().getIntOrFloatBitWidth();
+
+    llvm::SmallVector<APInt, 10> values;
+    values.reserve(output_size);
+
+    int64_t increase_stride = output_size;
+    for (int i = 0; i <= dimension; i++) {
+      increase_stride /= output_type.getDimSize(i);
+    }
+
+    int64_t current_value = 0;
+    for (int i = 0; i < output_size; i++) {
+      int64_t value = (current_value / increase_stride) % max_dim_size;
+      values.push_back(APInt(bitwidth, value));
+      ++current_value;
+    }
+
+    rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
+        op, DenseIntElementsAttr::get(output_type, values));
+    return matchSuccess();
+  }
+};
+
 }  // end anonymous namespace
 
 namespace {
@@ -121,9 +156,7 @@
 void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
                               mlir::MLIRContext *ctx) {
   mlir::populateWithGenerated(ctx, patterns);
-  patterns
-      ->insert<mlir::xla_hlo::CompareFConvert, mlir::xla_hlo::CompareIConvert>(
-          ctx);
+  patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
 }
 
 /// Perform the lowering to standard dialect.