[ST] Add loop peeling pass for gml_st.loop.
PiperOrigin-RevId: 426982579
Change-Id: I10691615a84d8d66ea3ca5dbd89d44eef0182ebb
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index cd0b431..1300d73 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -1567,6 +1567,38 @@
)
gentbl_cc_library(
+ name = "gml_st_test_passes_inc_gen",
+ compatible_with = get_compatible_with_cloud(),
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ [
+ "-gen-pass-decls",
+ "-name=GmlStTest",
+ ],
+ "include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td",
+ deps = ["@llvm-project//mlir:PassBaseTdFiles"],
+)
+
+cc_library(
+ name = "all_test_passes",
+ srcs = ["lib/Dialect/gml_st/transforms/test_passes.cc"],
+ hdrs = ["include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h"],
+ includes = ["include"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":gml_st_test_passes_inc_gen",
+ ":gml_st_transforms",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+gentbl_cc_library(
name = "transforms_pass_inc_gen",
compatible_with = get_compatible_with_cloud(),
strip_include_prefix = "include",
@@ -1766,11 +1798,10 @@
cc_binary(
name = "mlir-hlo-opt",
- srcs = [
- "tools/mlir-hlo-opt/mlir-hlo-opt.cpp",
- ],
+ srcs = ["tools/mlir-hlo-opt/mlir-hlo-opt.cpp"],
deps = [
":all_passes",
+ ":all_test_passes",
":gml_st",
":hlo_dialect_registration",
":lhlo",
@@ -2017,3 +2048,17 @@
"@llvm-project//mlir:Transforms",
],
)
+
+cc_library(
+ name = "gml_st_transforms",
+ srcs = ["lib/Dialect/gml_st/transforms/transforms.cc"],
+ hdrs = ["include/mlir-hlo/Dialect/gml_st/transforms/transforms.h"],
+ includes = ["include"],
+ deps = [
+ ":gml_st",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:SCFUtils",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt
index 44c079e..f6a9948 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt
@@ -18,6 +18,10 @@
mlir_tablegen(passes.h.inc -gen-pass-decls -name GmlSt)
add_public_tablegen_target(MLIRGmlStPassIncGen)
+set(LLVM_TARGET_DEFINITIONS test_passes.td)
+mlir_tablegen(test_passes.h.inc -gen-pass-decls -name GmlStTest)
+add_public_tablegen_target(MLIRGmlStTestPassIncGen)
+
set(LLVM_TARGET_DEFINITIONS tiling_interface.td)
mlir_tablegen(tiling_interface.h.inc -gen-op-interface-decls)
mlir_tablegen(tiling_interface.cc.inc -gen-op-interface-defs)
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h
new file mode 100644
index 0000000..b4dc704
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h
@@ -0,0 +1,32 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace gml_st {
+
+std::unique_ptr<OperationPass<FuncOp>> createTestGmlStLoopPeelingPass();
+
+#define GEN_PASS_REGISTRATION
+#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc"
+
+} // namespace gml_st
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H_
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td
new file mode 100644
index 0000000..24f5892
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td
@@ -0,0 +1,27 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "mlir/Pass/PassBase.td"
+
+def TestGmlStLoopPeeling : Pass<"test-gml-st-loop-peeling", "mlir::FuncOp"> {
+ let summary = "Peel `gml_st.loop`";
+ let constructor = "::mlir::gml_st::createTestGmlStLoopPeelingPass()";
+ let options = [
+ Option<"skip_partial", "skip-partial", "bool", /*default=*/"false",
+ "Skip loops inside partial iterations during peeling">,
+ ListOption<"dims", "dims", "unsigned", "Dimensions to peel",
+ "llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+ ];
+}
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h
new file mode 100644
index 0000000..f71e44e
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h
@@ -0,0 +1,54 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace gml_st {
+
+/// Rewrite a gml_st::LoopOp with bounds/step that potentially do not divide
+/// evenly into a gml_st::LoopOp where the step divides the iteration space
+/// evenly, followed by another gml_st::LoopOp for the last (partial) iteration
+/// (if any). This transformation is called "loop peeling".
+///
+/// This function peels the `idx`-th loop of the gml_st::LoopOp. To peel all
+/// loops in the loop nest, this function must be called multiple times.
+///
+/// After loop peeling, this function tries to simplify/canonicalize affine.min
+/// and affine.max ops in the body of the two gml_st::LoopOps. For more details,
+/// refer to `mlir::scf::peelAndCanonicalizeForLoop`.
+///
+/// The return value indicates whether the loop was rewritten or not. Loops are
+/// not rewritten if:
+/// * Loop step size is 1 or
+/// * Loop bounds and step size are static, and step already divides the
+/// iteration space evenly.
+///
+/// Note: This function rewrites the given gml_st::LoopOp in-place and clones
+/// the gml_st::LoopOp operation for the last iteration. It replaces all uses of
+/// the unpeeled gml_st::LoopOp with the results of the newly generated
+/// gml_st::LoopOp.
+LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
+ LoopOp loopOp, int64_t idx,
+ LoopOp &result);
+
+} // namespace gml_st
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H_
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
index 9352d40..9484050 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
@@ -1,4 +1,3 @@
-#
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -45,3 +44,31 @@
MLIRPass
MLIRSupport
)
+
+add_mlir_library(GmlStTransforms
+ transforms.cc
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRAffine
+ MLIRDialectUtils
+ MLIRIR
+ MLIRSCFUtils
+)
+
+add_mlir_library(GmlStTestPasses
+ test_passes.cc
+
+ DEPENDS
+ MLIRGmlStTestPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ GmlStTransforms
+ MLIRPass
+ MLIRTransforms
+)
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/test_passes.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/test_passes.cc
new file mode 100644
index 0000000..1a83b92
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/test_passes.cc
@@ -0,0 +1,109 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h"
+
+#include <utility>
+
+#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace gml_st {
+namespace {
+
+#define GEN_PASS_CLASSES
+#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc"
+
+static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
+static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
+
+/// Peel LoopOps, i.e., split them into two loops: One loop where the
+/// `idx`-th loop contains only "full" iterations and a second loop for the
+/// remaining partial iteration (if any).
+struct TiledLoopPeelingPattern : public OpRewritePattern<LoopOp> {
+ TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skip_partial)
+ : OpRewritePattern<LoopOp>(ctx), idx(idx), skip_partial(skip_partial) {}
+
+ LogicalResult matchAndRewrite(LoopOp loopOp,
+ PatternRewriter &rewriter) const override {
+ SmallVector<int64_t> peeledLoops;
+ if (loopOp->hasAttr(kPeeledLoopsLabel)) {
+ auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
+ peeledLoops =
+ llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
+ return attr.cast<IntegerAttr>().getInt();
+ }));
+ // Check if the loop was already peeled.
+ if (llvm::find(peeledLoops, idx) != peeledLoops.end()) return failure();
+ }
+ if (skip_partial && loopOp->hasAttr(kPartialIterationLabel))
+ // No peeling of loop nests with a partial iteration.
+ return failure();
+
+ if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
+ return failure();
+
+ // Peel loop and canonicalize.
+ LoopOp result;
+ if (failed(peelAndCanonicalizeGmlStLoop(rewriter, loopOp, idx, result)))
+ return failure();
+
+ // Apply label, so that the same loop is not rewritten a second time.
+ peeledLoops.push_back(idx);
+ rewriter.updateRootInPlace(loopOp, [&]() {
+ loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
+ });
+ result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
+ result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
+
+ return success();
+ }
+
+ /// Index of loop to peel.
+ int64_t idx;
+
+ /// If set to true, do not peel LoopOps with a partial iteration.
+ bool skip_partial;
+};
+
+class TestGmlStLoopPeelingPass
+ : public TestGmlStLoopPeelingBase<TestGmlStLoopPeelingPass> {
+ void runOnOperation() final {
+ auto funcOp = getOperation();
+ MLIRContext *ctx = funcOp.getContext();
+ RewritePatternSet patterns(ctx);
+ for (unsigned idx : dims)
+ patterns.add<TiledLoopPeelingPattern>(ctx, idx, skip_partial);
+
+ (void)(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)));
+
+ // Drop the markers.
+ funcOp.walk([](LoopOp op) {
+ op->removeAttr(kPeeledLoopsLabel);
+ op->removeAttr(kPartialIterationLabel);
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createTestGmlStLoopPeelingPass() {
+ return std::make_unique<TestGmlStLoopPeelingPass>();
+}
+
+} // namespace gml_st
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc
new file mode 100644
index 0000000..7f16b0a
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc
@@ -0,0 +1,136 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+namespace mlir {
+namespace gml_st {
+namespace {
+
+/// Rewrite a LoopOp with bounds/step that potentially do not divide evenly
+/// into two LoopOps: One where the step divides the iteration space
+/// evenly, followed by another one for the last (partial) iteration (if any). This
+/// function only rewrites the `idx`-th loop of the loop nest represented by
+/// the LoopOp. To peel the entire loop nest, this function must be called
+/// multiple times.
+///
+/// This function rewrites the given LoopOp in-place and creates a new
+/// LoopOp for the last iteration. It replaces all uses of the original
+/// LoopOp with the results of the newly generated one.
+///
+/// The newly generated LoopOp is returned via `result`. The boundary
+/// at which the loop is split (new upper bound) is returned via `splitBound`.
+/// The return value indicates whether the LoopOp was rewritten or not.
+static LogicalResult peelLoop(RewriterBase &b, LoopOp loopOp, int64_t idx,
+ LoopOp &result, Value &splitBound) {
+ Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx],
+ step = loopOp.step()[idx];
+ auto ubInt = getConstantIntValue(ub);
+
+ auto loc = loopOp.getLoc();
+ AffineExpr exprLb, exprUb, exprStep;
+ bindSymbols(b.getContext(), exprLb, exprUb, exprStep);
+ // New upper bound: %ub - (%ub - %lb) mod %step
+ auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)});
+ SmallVector<Value> operands{lb, ub, step};
+ canonicalizeMapAndOperands(&modMap, &operands);
+ modMap = simplifyAffineMap(modMap);
+ RewriterBase::InsertionGuard guard(b);
+ b.setInsertionPoint(loopOp);
+ splitBound = b.createOrFold<AffineApplyOp>(loc, modMap, operands);
+ // No specialization necessary if step already divides upper bound evenly.
+ if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound)))
+ return failure();
+
+ // Create remainder loop.
+ b.setInsertionPointAfter(loopOp);
+ auto remainderLoop = cast<LoopOp>(b.clone(*loopOp.getOperation()));
+ loopOp.replaceAllUsesWith(remainderLoop->getResults());
+ // Outputs: Take tensors from main loop's results. Take memrefs from main
+ // loop's outputs.
+ SmallVector<Value> remainderOutputs;
+ for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) {
+ remainderOutputs.push_back(loopOp.outputs()[o].getType().isa<MemRefType>()
+ ? loopOp.outputs()[o]
+ : loopOp->getResult(t++));
+ }
+ remainderLoop.outputsMutable().assign(remainderOutputs);
+
+ // Set new loop bounds.
+ b.updateRootInPlace(loopOp, [&]() {
+ SmallVector<Value> ubs = loopOp.upperBound();
+ ubs[idx] = splitBound;
+ loopOp.upperBoundMutable().assign(ubs);
+ });
+ SmallVector<Value> lbs = remainderLoop.lowerBound();
+ lbs[idx] = splitBound;
+ remainderLoop.lowerBoundMutable().assign(lbs);
+
+ result = remainderLoop;
+ return success();
+}
+
+template <typename OpTy, bool IsMin>
+static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, LoopOp mainLoop,
+ LoopOp remainderLoop, Value mainIv,
+ Value remainderIv, Value ub,
+ Value step) {
+ mainLoop.walk([&](OpTy affineOp) {
+ AffineMap map = affineOp.getAffineMap();
+ (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
+ affineOp.operands(), IsMin, mainIv, ub,
+ step, /*insideLoop=*/true);
+ });
+ remainderLoop.walk([&](OpTy affineOp) {
+ AffineMap map = affineOp.getAffineMap();
+ (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
+ affineOp.operands(), IsMin, remainderIv,
+ ub, step, /*insideLoop=*/false);
+ });
+}
+
+} // namespace
+
+LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
+ LoopOp loopOp, int64_t idx,
+ LoopOp &result) {
+ int64_t numLoops = loopOp.iterator_types().size();
+ if (idx < 0 || numLoops <= idx) return failure();
+
+ Value ub = loopOp.upperBound()[idx];
+ LoopOp remainderLoop;
+ Value splitBound;
+ if (failed(peelLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
+ return failure();
+
+ // Rewrite affine.min and affine.max ops.
+ Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.step()[idx],
+ remainderIv = remainderLoop.getInductionVars()[idx];
+
+ rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
+ rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
+ rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
+ rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
+
+ result = remainderLoop;
+ return success();
+}
+
+} // namespace gml_st
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/loop_peeling.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/loop_peeling.mlir
new file mode 100644
index 0000000..b85afba
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/loop_peeling.mlir
@@ -0,0 +1,231 @@
+// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -test-gml-st-loop-peeling="dims=2" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-2
+// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -test-gml-st-loop-peeling="dims=0,1,2" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012
+// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -test-gml-st-loop-peeling="dims=0,1,2 skip-partial" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012-SKIP-PARTIAL
+
+// CHECK-TILE-2-LABEL: func @loop_3d_tensor(
+// CHECK-TILE-2-SAME: %[[input:.*]]: tensor<?x?x?xf32>, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-TILE-2-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-TILE-2-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-TILE-2-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-TILE-2: %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]]
+// CHECK-TILE-2: %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]]
+// CHECK-TILE-2: %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]]
+// CHECK-TILE-2: %[[init_tensor:.*]] = linalg.init_tensor
+// CHECK-TILE-2: %[[split_bound:.*]] = affine.apply
+// CHECK-TILE-2: %[[r1:.*]] = gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]])
+// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[split_bound]])
+// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]])
+// CHECK-TILE-2-SAME: ins (%[[loop_in1:.*]] = %[[input]]: tensor<?x?x?xf32>)
+// CHECK-TILE-2-SAME: outs (%[[loop_out1:.*]] = %[[init_tensor]]: tensor<?x?x?xf32>) {
+// CHECK-TILE-2: %[[min0_1:.*]] = affine.min
+// CHECK-TILE-2: %[[min1_1:.*]] = affine.min
+// CHECK-TILE-2: %[[in_slice1:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
+// CHECK-TILE-2: %[[out_slice1:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
+// CHECK-TILE-2: %[[mod_slice1:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
+// CHECK-TILE-2: gml_st.yield %[[mod_slice1]]
+// CHECK-TILE-2: %[[r2:.*]] = gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]])
+// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[dim2]])
+// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]])
+// CHECK-TILE-2-SAME: ins (%[[loop_in2:.*]] = %[[input]]: tensor<?x?x?xf32>)
+// CHECK-TILE-2-SAME: outs (%[[loop_out2:.*]] = %[[r1]]: tensor<?x?x?xf32>) {
+// CHECK-TILE-2: %[[min0_2:.*]] = affine.min
+// CHECK-TILE-2: %[[min1_2:.*]] = affine.min
+// CHECK-TILE-2: %[[apply2:.*]] = affine.apply
+// CHECK-TILE-2: %[[in_slice2:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
+// CHECK-TILE-2: %[[out_slice2:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
+// CHECK-TILE-2: %[[mod_slice2:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
+// CHECK-TILE-2: gml_st.yield %[[mod_slice2]]
+// CHECK-TILE-2: return %[[r2]]
+
+// CHECK-TILE-012-LABEL: func @loop_3d_tensor
+// CHECK-TILE-012: gml_st.loop {{.*}} {
+// CHECK-TILE-012: gml_st.yield
+// CHECK-TILE-012: }
+// CHECK-TILE-012: gml_st.loop {{.*}} {
+// CHECK-TILE-012: gml_st.yield
+// CHECK-TILE-012: }
+// CHECK-TILE-012: gml_st.loop {{.*}} {
+// CHECK-TILE-012: gml_st.yield
+// CHECK-TILE-012: }
+// CHECK-TILE-012: gml_st.loop {{.*}} {
+// CHECK-TILE-012: gml_st.yield
+// CHECK-TILE-012: }
+// CHECK-TILE-012: gml_st.loop {{.*}} {
+// CHECK-TILE-012: gml_st.yield
+// CHECK-TILE-012: }
+// CHECK-TILE-012: gml_st.loop {{.*}} {
+// CHECK-TILE-012: gml_st.yield
+// CHECK-TILE-012: }
+// CHECK-TILE-012: gml_st.loop {{.*}} {
+// CHECK-TILE-012: gml_st.yield
+// CHECK-TILE-012: }
+// CHECK-TILE-012: gml_st.loop {{.*}} {
+// CHECK-TILE-012: gml_st.yield
+// CHECK-TILE-012: }
+// CHECK-TILE-012-NOT: gml_st.loop
+
+// CHECK-TILE-012-SKIP-PARTIAL: func @loop_3d_tensor(
+// CHECK-TILE-012-SKIP-PARTIAL-SAME: %[[input:.*]]: tensor<?x?x?xf32>
+// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]]
+// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]]
+// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]]
+// CHECK-TILE-012-SKIP-PARTIAL: %[[p0:.*]] = affine.apply #{{.*}}()[%[[dim0]]
+// CHECK-TILE-012-SKIP-PARTIAL: %[[p1:.*]] = affine.apply #{{.*}}()[%[[dim1]]
+// CHECK-TILE-012-SKIP-PARTIAL: %[[p2:.*]] = affine.apply #{{.*}}()[%[[dim2]]
+// CHECK-TILE-012-SKIP-PARTIAL: gml_st.loop {{.*}} = (%[[c0]], %[[c0]], %[[c0]]) to (%[[p0]], %[[p1]], %[[p2]])
+// CHECK-TILE-012-SKIP-PARTIAL: gml_st.loop {{.*}} = (%[[c0]], %[[c0]], %[[p2]]) to (%[[p0]], %[[p1]], %[[dim2]])
+// CHECK-TILE-012-SKIP-PARTIAL: gml_st.loop {{.*}} = (%[[c0]], %[[p1]], %[[c0]]) to (%[[p0]], %[[dim1]], %[[dim2]])
+// CHECK-TILE-012-SKIP-PARTIAL: gml_st.loop {{.*}} = (%[[p0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim1]], %[[dim2]])
+func @loop_3d_tensor(%arg0: tensor<?x?x?xf32>, %s0: index, %s1: index,
+ %s2: index) -> tensor<?x?x?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+ %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
+ %result = gml_st.loop
+ (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
+ step (%s0, %s1, %s2) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
+ outs (%arg5 = %output: tensor<?x?x?xf32>) {
+ %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0]
+ %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1]
+ %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2]
+ %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+ gml_st.yield %updated_slice : tensor<?x?x?xf32>
+ }
+ return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-TILE-2-LABEL: func @loop_3d_memref(
+// CHECK-TILE-2-SAME: %[[input:.*]]: memref<?x?x?xf32>, %[[output:.*]]: memref<?x?x?xf32>, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-TILE-2-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-TILE-2-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-TILE-2-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-TILE-2: %[[dim0:.*]] = memref.dim %[[input]], %[[c0]]
+// CHECK-TILE-2: %[[dim1:.*]] = memref.dim %[[input]], %[[c1]]
+// CHECK-TILE-2: %[[dim2:.*]] = memref.dim %[[input]], %[[c2]]
+// CHECK-TILE-2: %[[split_bound:.*]] = affine.apply
+// CHECK-TILE-2: gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]])
+// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[split_bound]])
+// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]])
+// CHECK-TILE-2-SAME: ins (%[[loop_in1:.*]] = %[[input]]: memref<?x?x?xf32>)
+// CHECK-TILE-2-SAME: outs (%[[loop_out1:.*]] = %[[output]]: memref<?x?x?xf32>) {
+// CHECK-TILE-2: %[[min0_1:.*]] = affine.min
+// CHECK-TILE-2: %[[min1_1:.*]] = affine.min
+// CHECK-TILE-2: memref.subview %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]]
+// CHECK-TILE-2: gml_st.yield
+// CHECK-TILE-2: gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]])
+// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[dim2]])
+// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]])
+// CHECK-TILE-2-SAME: ins (%[[loop_in2:.*]] = %[[input]]: memref<?x?x?xf32>)
+// CHECK-TILE-2-SAME: outs (%[[loop_out2:.*]] = %[[output]]: memref<?x?x?xf32>) {
+// CHECK-TILE-2: %[[min0_2:.*]] = affine.min
+// CHECK-TILE-2: %[[min1_2:.*]] = affine.min
+// CHECK-TILE-2: %[[apply2:.*]] = affine.apply
+// CHECK-TILE-2: memref.subview %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]]
+// CHECK-TILE-2: gml_st.yield
+// CHECK-TILE-2: return
+
+// CHECK-TILE-012-LABEL: func @loop_3d_memref
+
+!memref_subview_type = type memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>
+
+func @loop_3d_memref(%arg0: memref<?x?x?xf32>, %output: memref<?x?x?xf32>,
+ %s0: index, %s1: index, %s2: index) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+ %dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+ %dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+ %dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+ gml_st.loop
+ (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
+ step (%s0, %s1, %s2) ins (%arg4 = %arg0: memref<?x?x?xf32>)
+ outs (%arg5 = %output : memref<?x?x?xf32>) {
+ %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0]
+ %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1]
+ %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2]
+ %in_slice = memref.subview %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: memref<?x?x?xf32> to !memref_subview_type
+ "computation"(%in_slice) : (!memref_subview_type) -> memref<?x?x?xf32>
+ gml_st.yield
+ }
+ return
+}
+
+// -----
+
+// CHECK-TILE-2-LABEL: func @step_1_do_not_peel
+// CHECK-TILE-2: gml_st.loop
+// CHECK-TILE-2-NOT: gml_st.loop
+
+// CHECK-TILE-012-LABEL: func @step_1_do_not_peel
+
+func @step_1_do_not_peel(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+ %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
+ %result = gml_st.loop
+ (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2)
+ step (%c1, %c1, %c1) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
+ outs (%arg5 = %output: tensor<?x?x?xf32>) {
+ %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+ gml_st.yield %updated_slice : tensor<?x?x?xf32>
+ }
+ return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-TILE-2-LABEL: func @divides_evenly_do_not_peel
+// CHECK-TILE-2: gml_st.loop
+// CHECK-TILE-2-NOT: gml_st.loop
+
+// CHECK-TILE-012-LABEL: func @divides_evenly_do_not_peel
+
+func @divides_evenly_do_not_peel(%arg0: tensor<?x?x?xf32>, %s: index)
+ -> tensor<?x?x?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+ %c64 = arith.constant 64 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+ %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor<?x?x?xf32>
+ %result = gml_st.loop
+ (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %c64)
+ step (%s, %s, %c8) ins (%arg4 = %arg0: tensor<?x?x?xf32>)
+ outs (%arg5 = %output: tensor<?x?x?xf32>) {
+ %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %comp = "computation"(%in_slice, %out_slice) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+ gml_st.yield %updated_slice : tensor<?x?x?xf32>
+ }
+ return %result : tensor<?x?x?xf32>
+}
diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt
index e70da1c..b464b8b 100644
--- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt
@@ -21,6 +21,7 @@
MLIROptLib
GmlStPasses
+ GmlStTestPasses
GmlStDialect
LmhloDialect
LmhloGPUDialect
diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
index 15caee9..c3a1c64 100644
--- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
+++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
@@ -15,6 +15,7 @@
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
+#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/register_passes.h"
#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
@@ -31,6 +32,7 @@
mlir::lmhlo::registerAllLmhloPasses();
mlir::hlo::registerAllHloPasses();
mlir::gml_st::registerGmlStPasses();
+ mlir::gml_st::registerGmlStTestPasses();
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);