[ST] Add loop peeling pass for gml_st.loop.

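Adds `peelAndCanonicalizeGmlStLoop`, which splits a `gml_st.loop` whose step
may not divide the iteration space evenly into a main loop that runs only
full iterations and a remainder loop for the last (partial) iteration, and
then canonicalizes affine.min/affine.max ops in both loop bodies. A test-only
pass, `-test-gml-st-loop-peeling` (options `dims` and `skip-partial`),
exercises the transformation.

Schematic of the rewrite for the peeled dimension (illustrative, not verbatim
IR; see the new test tests/Dialect/gml_st/loop_peeling.mlir for real output):

  // %split = %ub - ((%ub - %lb) mod %step), materialized via affine.apply.
  %main = gml_st.loop (%i) = (%lb) to (%split) step (%step)
          ins (...) outs (%out = %init) { ... }  // full iterations only
  %rest = gml_st.loop (%i) = (%split) to (%ub) step (%step)
          ins (...) outs (%out = %main) { ... }  // partial iteration, if any
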
PiperOrigin-RevId: 426982579
Change-Id: I10691615a84d8d66ea3ca5dbd89d44eef0182ebb
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index cd0b431..1300d73 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -1567,6 +1567,38 @@
 )
 
 gentbl_cc_library(
+    name = "gml_st_test_passes_inc_gen",
+    compatible_with = get_compatible_with_cloud(),
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=GmlStTest",
+            ],
+            "include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td",
+    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
+)
+
+cc_library(
+    name = "all_test_passes",
+    srcs = ["lib/Dialect/gml_st/transforms/test_passes.cc"],
+    hdrs = ["include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h"],
+    includes = ["include"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":gml_st_test_passes_inc_gen",
+        ":gml_st_transforms",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+gentbl_cc_library(
     name = "transforms_pass_inc_gen",
     compatible_with = get_compatible_with_cloud(),
     strip_include_prefix = "include",
@@ -1766,11 +1798,10 @@
 
 cc_binary(
     name = "mlir-hlo-opt",
-    srcs = [
-        "tools/mlir-hlo-opt/mlir-hlo-opt.cpp",
-    ],
+    srcs = ["tools/mlir-hlo-opt/mlir-hlo-opt.cpp"],
     deps = [
         ":all_passes",
+        ":all_test_passes",
         ":gml_st",
         ":hlo_dialect_registration",
         ":lhlo",
@@ -2017,3 +2048,17 @@
         "@llvm-project//mlir:Transforms",
     ],
 )
+
+cc_library(
+    name = "gml_st_transforms",
+    srcs = ["lib/Dialect/gml_st/transforms/transforms.cc"],
+    hdrs = ["include/mlir-hlo/Dialect/gml_st/transforms/transforms.h"],
+    includes = ["include"],
+    deps = [
+        ":gml_st",
+        "@llvm-project//mlir:Affine",
+        "@llvm-project//mlir:DialectUtils",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:SCFUtils",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt
index 44c079e..f6a9948 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt
@@ -18,6 +18,10 @@
 mlir_tablegen(passes.h.inc -gen-pass-decls -name GmlSt)
 add_public_tablegen_target(MLIRGmlStPassIncGen)
 
+set(LLVM_TARGET_DEFINITIONS test_passes.td)
+mlir_tablegen(test_passes.h.inc -gen-pass-decls -name GmlStTest)
+add_public_tablegen_target(MLIRGmlStTestPassIncGen)
+
 set(LLVM_TARGET_DEFINITIONS tiling_interface.td)
 mlir_tablegen(tiling_interface.h.inc -gen-op-interface-decls)
 mlir_tablegen(tiling_interface.cc.inc -gen-op-interface-defs)
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h
new file mode 100644
index 0000000..b4dc704
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h
@@ -0,0 +1,32 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace gml_st {
+
+std::unique_ptr<OperationPass<FuncOp>> createTestGmlStLoopPeelingPass();
+
+#define GEN_PASS_REGISTRATION
+#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc"
+
+}  // namespace gml_st
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H_
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td
new file mode 100644
index 0000000..24f5892
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td
@@ -0,0 +1,27 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "mlir/Pass/PassBase.td"
+
+def TestGmlStLoopPeeling : Pass<"test-gml-st-loop-peeling", "mlir::FuncOp"> {
+  let summary = "Peel `gml_st.loop`";
+  let constructor = "::mlir::gml_st::createTestGmlStLoopPeelingPass()";
+  let options = [
+    Option<"skip_partial", "skip-partial", "bool", /*default=*/"false",
+           "Skip loops inside partial iterations during peeling">,
+    ListOption<"dims", "dims", "unsigned", "Dimensions to peel",
+           "llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+  ];
+}
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h
new file mode 100644
index 0000000..f71e44e
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h
@@ -0,0 +1,62 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace gml_st {
+
+/// Rewrite a gml_st::LoopOp with bounds/step that potentially do not divide
+/// evenly into a gml_st::LoopOp where the step divides the iteration space
+/// evenly, followed by another gml_st::LoopOp for the last (partial) iteration
+/// (if any). This transformation is called "loop peeling".
+///
+/// This function peels the `idx`-th loop of the gml_st::LoopOp. To peel all
+/// loops in the loop nest, this function must be called multiple times.
+///
+/// After loop peeling, this function tries to simplify/canonicalize affine.min
+/// and affine.max ops in the body of the two gml_st::LoopOps. For more details,
+/// refer to `mlir::scf::peelAndCanonicalizeForLoop`.
+///
+/// The return value indicates whether the loop was rewritten or not. Loops are
+/// not rewritten if:
+/// * Loop step size is 1 or
+/// * Loop bounds and step size are static, and step already divides the
+///   iteration space evenly.
+///
+/// Note: This function rewrites the given gml_st::LoopOp in-place and clones
+/// the gml_st::LoopOp operation for the last iteration. It replaces all uses of
+/// the unpeeled gml_st::LoopOp with the results of the newly generated
+/// gml_st::LoopOp.
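+///
+/// A minimal usage sketch (variable names are illustrative; mirrors the test
+/// pass in lib/Dialect/gml_st/transforms/test_passes.cc):
+///
+///   LoopOp remainderLoop;
+///   if (failed(peelAndCanonicalizeGmlStLoop(rewriter, loopOp, /*idx=*/0,
+///                                           remainderLoop)))
+///     return failure();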
+LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
+                                           LoopOp loopOp, int64_t idx,
+                                           LoopOp &result);
+
+}  // namespace gml_st
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H_
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
index 9352d40..9484050 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
@@ -1,4 +1,3 @@
-#
 # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -45,3 +44,31 @@
   MLIRPass
   MLIRSupport
 )
+
+add_mlir_library(GmlStTransforms
+  transforms.cc
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRAffine
+  MLIRDialectUtils
+  MLIRIR
+  MLIRSCFUtils
+)
+
+add_mlir_library(GmlStTestPasses
+  test_passes.cc
+
+  DEPENDS
+  MLIRGmlStTestPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  GmlStTransforms
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/test_passes.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/test_passes.cc
new file mode 100644
index 0000000..1a83b92
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/test_passes.cc
@@ -0,0 +1,109 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h"
+
+#include <utility>
+
+#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace gml_st {
+namespace {
+
+#define GEN_PASS_CLASSES
+#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc"
+
+static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
+static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
+
+/// Peel LoopOps, i.e., split each into two loops: one whose `idx`-th loop
+/// contains only "full" iterations and a second loop for the remaining
+/// partial iteration (if any).
+struct TiledLoopPeelingPattern : public OpRewritePattern<LoopOp> {
+  TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skip_partial)
+      : OpRewritePattern<LoopOp>(ctx), idx(idx), skip_partial(skip_partial) {}
+
+  LogicalResult matchAndRewrite(LoopOp loopOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t> peeledLoops;
+    if (loopOp->hasAttr(kPeeledLoopsLabel)) {
+      auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
+      peeledLoops =
+          llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
+            return attr.cast<IntegerAttr>().getInt();
+          }));
+      // Check if the loop was already peeled.
+      if (llvm::find(peeledLoops, idx) != peeledLoops.end()) return failure();
+    }
+    if (skip_partial && loopOp->hasAttr(kPartialIterationLabel))
+      // No peeling of loop nests with a partial iteration.
+      return failure();
+
+    if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
+      return failure();
+
+    // Peel loop and canonicalize.
+    LoopOp result;
+    if (failed(peelAndCanonicalizeGmlStLoop(rewriter, loopOp, idx, result)))
+      return failure();
+
+    // Apply label, so that the same loop is not rewritten a second time.
+    peeledLoops.push_back(idx);
+    rewriter.updateRootInPlace(loopOp, [&]() {
+      loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
+    });
+    result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
+    result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
+
+    return success();
+  }
+
+  /// Index of loop to peel.
+  int64_t idx;
+
+  /// If set to true, do not peel LoopOps with a partial iteration.
+  bool skip_partial;
+};
+
+class TestGmlStLoopPeelingPass
+    : public TestGmlStLoopPeelingBase<TestGmlStLoopPeelingPass> {
+  void runOnOperation() final {
+    auto funcOp = getOperation();
+    MLIRContext *ctx = funcOp.getContext();
+    RewritePatternSet patterns(ctx);
+    for (unsigned idx : dims)
+      patterns.add<TiledLoopPeelingPattern>(ctx, idx, skip_partial);
+
+    (void)(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)));
+
+    // Drop the markers.
+    funcOp.walk([](LoopOp op) {
+      op->removeAttr(kPeeledLoopsLabel);
+      op->removeAttr(kPartialIterationLabel);
+    });
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createTestGmlStLoopPeelingPass() {
+  return std::make_unique<TestGmlStLoopPeelingPass>();
+}
+
+}  // namespace gml_st
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc
new file mode 100644
index 0000000..7f16b0a
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc
@@ -0,0 +1,136 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+namespace mlir {
+namespace gml_st {
+namespace {
+
+/// Rewrite a LoopOp with bounds/step that potentially do not divide evenly
+/// into two LoopOps: one where the step divides the iteration space evenly,
+/// followed by another one for the last (partial) iteration (if any). This
+/// function only rewrites the `idx`-th loop of the loop nest represented by
+/// the LoopOp. To peel the entire loop nest, this function must be called
+/// multiple times.
+///
+/// This function rewrites the given LoopOp in-place and creates a new
+/// LoopOp for the last iteration. It replaces all uses of the original
+/// LoopOp with the results of the newly generated one.
+///
+/// The newly generated LoopOp is returned via `result`. The boundary
+/// at which the loop is split (new upper bound) is returned via `splitBound`.
+/// The return value indicates whether the LoopOp was rewritten or not.
+static LogicalResult peelLoop(RewriterBase &b, LoopOp loopOp, int64_t idx,
+                              LoopOp &result, Value &splitBound) {
+  Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx],
+        step = loopOp.step()[idx];
+  auto ubInt = getConstantIntValue(ub);
+
+  auto loc = loopOp.getLoc();
+  AffineExpr exprLb, exprUb, exprStep;
+  bindSymbols(b.getContext(), exprLb, exprUb, exprStep);
+  // New upper bound: %ub - (%ub - %lb) mod %step
+  auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)});
+  SmallVector<Value> operands{lb, ub, step};
+  canonicalizeMapAndOperands(&modMap, &operands);
+  modMap = simplifyAffineMap(modMap);
+  RewriterBase::InsertionGuard guard(b);
+  b.setInsertionPoint(loopOp);
+  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap, operands);
+  // No specialization necessary if step already divides upper bound evenly.
+  if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound)))
+    return failure();
+
+  // Create remainder loop.
+  b.setInsertionPointAfter(loopOp);
+  auto remainderLoop = cast<LoopOp>(b.clone(*loopOp.getOperation()));
+  loopOp.replaceAllUsesWith(remainderLoop->getResults());
+  // Outputs: Take tensors from main loop's results. Take memrefs from main
+  // loop's outputs.
+  SmallVector<Value> remainderOutputs;
+  for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) {
+    remainderOutputs.push_back(loopOp.outputs()[o].getType().isa<MemRefType>()
+                                   ? loopOp.outputs()[o]
+                                   : loopOp->getResult(t++));
+  }
+  remainderLoop.outputsMutable().assign(remainderOutputs);
+
+  // Set new loop bounds.
+  b.updateRootInPlace(loopOp, [&]() {
+    SmallVector<Value> ubs = loopOp.upperBound();
+    ubs[idx] = splitBound;
+    loopOp.upperBoundMutable().assign(ubs);
+  });
+  SmallVector<Value> lbs = remainderLoop.lowerBound();
+  lbs[idx] = splitBound;
+  remainderLoop.lowerBoundMutable().assign(lbs);
+
+  result = remainderLoop;
+  return success();
+}
+
+template <typename OpTy, bool IsMin>
+static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, LoopOp mainLoop,
+                                        LoopOp remainderLoop, Value mainIv,
+                                        Value remainderIv, Value ub,
+                                        Value step) {
+  mainLoop.walk([&](OpTy affineOp) {
+    AffineMap map = affineOp.getAffineMap();
+    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
+                                     affineOp.operands(), IsMin, mainIv, ub,
+                                     step, /*insideLoop=*/true);
+  });
+  remainderLoop.walk([&](OpTy affineOp) {
+    AffineMap map = affineOp.getAffineMap();
+    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
+                                     affineOp.operands(), IsMin, remainderIv,
+                                     ub, step, /*insideLoop=*/false);
+  });
+}
+
+}  // namespace
+
+LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
+                                           LoopOp loopOp, int64_t idx,
+                                           LoopOp &result) {
+  int64_t numLoops = loopOp.iterator_types().size();
+  if (idx < 0 || numLoops <= idx) return failure();
+
+  Value ub = loopOp.upperBound()[idx];
+  LoopOp remainderLoop;
+  Value splitBound;
+  if (failed(peelLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
+    return failure();
+
+  // Rewrite affine.min and affine.max ops.
+  Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.step()[idx],
+        remainderIv = remainderLoop.getInductionVars()[idx];
+
+  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
+      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
+  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
+      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
+
+  result = remainderLoop;
+  return success();
+}
+
+}  // namespace gml_st
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/loop_peeling.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/loop_peeling.mlir
new file mode 100644
index 0000000..b85afba
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/loop_peeling.mlir
@@ -0,0 +1,231 @@
+// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -test-gml-st-loop-peeling="dims=2" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-2
+// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -test-gml-st-loop-peeling="dims=0,1,2" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012
+// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -test-gml-st-loop-peeling="dims=0,1,2 skip-partial" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012-SKIP-PARTIAL
+
+// CHECK-TILE-2-LABEL: func @loop_3d_tensor(
+//  CHECK-TILE-2-SAME:     %[[input:.*]]: tensor<?x?x?xf32>, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-TILE-2-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-TILE-2-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//   CHECK-TILE-2-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//       CHECK-TILE-2:   %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]]
+//       CHECK-TILE-2:   %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]]
+//       CHECK-TILE-2:   %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]]
+//       CHECK-TILE-2:   %[[init_tensor:.*]] = linalg.init_tensor
+//       CHECK-TILE-2:   %[[split_bound:.*]] = affine.apply
+//       CHECK-TILE-2:   %[[r1:.*]] = gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]])
+//  CHECK-TILE-2-SAME:       to (%[[dim0]], %[[dim1]], %[[split_bound]])
+//  CHECK-TILE-2-SAME:       step (%[[s0]], %[[s1]], %[[s2]])
+//  CHECK-TILE-2-SAME:       ins (%[[loop_in1:.*]] = %[[input]]: tensor<?x?x?xf32>)
+//  CHECK-TILE-2-SAME:       outs (%[[loop_out1:.*]] = %[[init_tensor]]: tensor<?x?x?xf32>) {
+//       CHECK-TILE-2:     %[[min0_1:.*]] = affine.min
+//       CHECK-TILE-2:     %[[min1_1:.*]] = affine.min
+//       CHECK-TILE-2:     %[[in_slice1:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
+//       CHECK-TILE-2:     %[[out_slice1:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
+//       CHECK-TILE-2:     %[[mod_slice1:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
+//       CHECK-TILE-2:     gml_st.yield %[[mod_slice1]]
+//       CHECK-TILE-2:   %[[r2:.*]] = gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]])
+//  CHECK-TILE-2-SAME:       to (%[[dim0]], %[[dim1]], %[[dim2]])
+//  CHECK-TILE-2-SAME:       step (%[[s0]], %[[s1]], %[[s2]])
+//  CHECK-TILE-2-SAME:       ins (%[[loop_in2:.*]] = %[[input]]: tensor<?x?x?xf32>)
+//  CHECK-TILE-2-SAME:       outs (%[[loop_out2:.*]] = %[[r1]]: tensor<?x?x?xf32>) {
+//       CHECK-TILE-2:     %[[min0_2:.*]] = affine.min
+//       CHECK-TILE-2:     %[[min1_2:.*]] = affine.min
+//       CHECK-TILE-2:     %[[apply2:.*]] = affine.apply
+//       CHECK-TILE-2:     %[[in_slice2:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
+//       CHECK-TILE-2:     %[[out_slice2:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
+//       CHECK-TILE-2:     %[[mod_slice2:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
+//       CHECK-TILE-2:     gml_st.yield %[[mod_slice2]]
+//       CHECK-TILE-2:   return %[[r2]]
+
+// CHECK-TILE-012-LABEL: func @loop_3d_tensor
+//       CHECK-TILE-012:   gml_st.loop {{.*}} {
+//       CHECK-TILE-012:     gml_st.yield
+//       CHECK-TILE-012:   }
+//       CHECK-TILE-012:   gml_st.loop {{.*}} {
+//       CHECK-TILE-012:     gml_st.yield
+//       CHECK-TILE-012:   }
+//       CHECK-TILE-012:   gml_st.loop {{.*}} {
+//       CHECK-TILE-012:     gml_st.yield
+//       CHECK-TILE-012:   }
+//       CHECK-TILE-012:   gml_st.loop {{.*}} {
+//       CHECK-TILE-012:     gml_st.yield
+//       CHECK-TILE-012:   }
+//       CHECK-TILE-012:   gml_st.loop {{.*}} {
+//       CHECK-TILE-012:     gml_st.yield
+//       CHECK-TILE-012:   }
+//       CHECK-TILE-012:   gml_st.loop {{.*}} {
+//       CHECK-TILE-012:     gml_st.yield
+//       CHECK-TILE-012:   }
+//       CHECK-TILE-012:   gml_st.loop {{.*}} {
+//       CHECK-TILE-012:     gml_st.yield
+//       CHECK-TILE-012:   }
+//       CHECK-TILE-012:   gml_st.loop {{.*}} {
+//       CHECK-TILE-012:     gml_st.yield
+//       CHECK-TILE-012:   }
+//   CHECK-TILE-012-NOT: gml_st.loop
+
+//      CHECK-TILE-012-SKIP-PARTIAL: func @loop_3d_tensor(
+// CHECK-TILE-012-SKIP-PARTIAL-SAME:     %[[input:.*]]: tensor<?x?x?xf32>
+//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]]
+//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]]
+//  CHECK-TILE-012-SKIP-PARTIAL-DAG:   %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]]
+//      CHECK-TILE-012-SKIP-PARTIAL:   %[[p0:.*]] = affine.apply #{{.*}}()[%[[dim0]]
+//      CHECK-TILE-012-SKIP-PARTIAL:   %[[p1:.*]] = affine.apply #{{.*}}()[%[[dim1]]
+//      CHECK-TILE-012-SKIP-PARTIAL:   %[[p2:.*]] = affine.apply #{{.*}}()[%[[dim2]]
+//      CHECK-TILE-012-SKIP-PARTIAL:   gml_st.loop {{.*}} = (%[[c0]], %[[c0]], %[[c0]]) to (%[[p0]], %[[p1]], %[[p2]])
+//      CHECK-TILE-012-SKIP-PARTIAL:   gml_st.loop {{.*}} = (%[[c0]], %[[c0]], %[[p2]]) to (%[[p0]], %[[p1]], %[[dim2]])
+//      CHECK-TILE-012-SKIP-PARTIAL:   gml_st.loop {{.*}} = (%[[c0]], %[[p1]], %[[c0]]) to (%[[p0]], %[[dim1]], %[[dim2]])
+//      CHECK-TILE-012-SKIP-PARTIAL:   gml_st.loop {{.*}} = (%[[p0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim1]], %[[dim2]])
+func @loop_3d_tensor(%arg0: tensor<?x?x?xf32>, %s0: index, %s1: index,
+                           %s2: index) -> tensor<?x?x?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c8 = arith.constant 8 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
+  %result = gml_st.loop
+           (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
+           step (%s0, %s1, %s2) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
+           outs (%arg5 = %output: tensor<?x?x?xf32>) {
+    %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0]
+    %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1]
+    %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2]
+    %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+    gml_st.yield %updated_slice : tensor<?x?x?xf32>
+  }
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-TILE-2-LABEL: func @loop_3d_memref(
+//  CHECK-TILE-2-SAME:     %[[input:.*]]: memref<?x?x?xf32>, %[[output:.*]]: memref<?x?x?xf32>, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-TILE-2-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-TILE-2-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//   CHECK-TILE-2-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//       CHECK-TILE-2:   %[[dim0:.*]] = memref.dim %[[input]], %[[c0]]
+//       CHECK-TILE-2:   %[[dim1:.*]] = memref.dim %[[input]], %[[c1]]
+//       CHECK-TILE-2:   %[[dim2:.*]] = memref.dim %[[input]], %[[c2]]
+//       CHECK-TILE-2:   %[[split_bound:.*]] = affine.apply
+//       CHECK-TILE-2:   gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]])
+//  CHECK-TILE-2-SAME:       to (%[[dim0]], %[[dim1]], %[[split_bound]])
+//  CHECK-TILE-2-SAME:       step (%[[s0]], %[[s1]], %[[s2]])
+//  CHECK-TILE-2-SAME:       ins (%[[loop_in1:.*]] = %[[input]]: memref<?x?x?xf32>)
+//  CHECK-TILE-2-SAME:       outs (%[[loop_out1:.*]] = %[[output]]: memref<?x?x?xf32>) {
+//       CHECK-TILE-2:     %[[min0_1:.*]] = affine.min
+//       CHECK-TILE-2:     %[[min1_1:.*]] = affine.min
+//       CHECK-TILE-2:     memref.subview %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
+//       CHECK-TILE-2:     gml_st.yield
+//       CHECK-TILE-2:   gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]])
+//  CHECK-TILE-2-SAME:       to (%[[dim0]], %[[dim1]], %[[dim2]])
+//  CHECK-TILE-2-SAME:       step (%[[s0]], %[[s1]], %[[s2]])
+//  CHECK-TILE-2-SAME:       ins (%[[loop_in2:.*]] = %[[input]]: memref<?x?x?xf32>)
+//  CHECK-TILE-2-SAME:       outs (%[[loop_out2:.*]] = %[[output]]: memref<?x?x?xf32>) {
+//       CHECK-TILE-2:     %[[min0_2:.*]] = affine.min
+//       CHECK-TILE-2:     %[[min1_2:.*]] = affine.min
+//       CHECK-TILE-2:     %[[apply2:.*]] = affine.apply
+//       CHECK-TILE-2:     memref.subview %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
+//       CHECK-TILE-2:     gml_st.yield
+//       CHECK-TILE-2:   return
+
+// CHECK-TILE-012-LABEL: func @loop_3d_memref
+
+!memref_subview_type = type memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>
+
+func @loop_3d_memref(%arg0: memref<?x?x?xf32>, %output: memref<?x?x?xf32>,
+                           %s0: index, %s1: index, %s2: index) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c8 = arith.constant 8 : index
+  %dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+  %dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+  %dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+  gml_st.loop
+           (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
+           step (%s0, %s1, %s2) ins (%arg4 = %arg0: memref<?x?x?xf32>)
+           outs (%arg5 = %output : memref<?x?x?xf32>) {
+    %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0]
+    %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1]
+    %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2]
+    %in_slice = memref.subview %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: memref<?x?x?xf32> to !memref_subview_type
+    "computation"(%in_slice) : (!memref_subview_type) -> memref<?x?x?xf32>
+    gml_st.yield
+  }
+  return
+}
+
+// -----
+
+// CHECK-TILE-2-LABEL: func @step_1_do_not_peel
+//       CHECK-TILE-2:   gml_st.loop
+//   CHECK-TILE-2-NOT:   gml_st.loop
+
+// CHECK-TILE-012-LABEL: func @step_1_do_not_peel
+
+func @step_1_do_not_peel(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c8 = arith.constant 8 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
+  %result = gml_st.loop
+           (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
+           step (%c1, %c1, %c1) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
+           outs (%arg5 = %output: tensor<?x?x?xf32>) {
+    %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+    gml_st.yield %updated_slice : tensor<?x?x?xf32>
+  }
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-TILE-2-LABEL: func @divides_evenly_do_not_peel
+//       CHECK-TILE-2:   gml_st.loop
+//   CHECK-TILE-2-NOT:   gml_st.loop
+
+// CHECK-TILE-012-LABEL: func @divides_evenly_do_not_peel
+
+func @divides_evenly_do_not_peel(%arg0: tensor<?x?x?xf32>, %s: index)
+    -> tensor<?x?x?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c8 = arith.constant 8 : index
+  %c64 = arith.constant 64 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
+  %result = gml_st.loop
+           (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %c64)
+           step (%s, %s, %c8) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
+           outs (%arg5 = %output: tensor<?x?x?xf32>) {
+    %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+    gml_st.yield %updated_slice : tensor<?x?x?xf32>
+  }
+  return %result : tensor<?x?x?xf32>
+}
diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt
index e70da1c..b464b8b 100644
--- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt
@@ -21,6 +21,7 @@
         MLIROptLib
 
         GmlStPasses
+        GmlStTestPasses
         GmlStDialect
         LmhloDialect
         LmhloGPUDialect
diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
index 15caee9..c3a1c64 100644
--- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
+++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
 #include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
+#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h"
 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/lhlo/transforms/register_passes.h"
 #include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
@@ -31,6 +32,7 @@
   mlir::lmhlo::registerAllLmhloPasses();
   mlir::hlo::registerAllHloPasses();
   mlir::gml_st::registerGmlStPasses();
+  mlir::gml_st::registerGmlStTestPasses();
 
   mlir::DialectRegistry registry;
   mlir::registerAllDialects(registry);