// Source blob: c9a9fad7bc7eaac8f945734e33ecb4d3bd1ea9f2
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h"
#include <string>
#include <utility>
#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace gml_st {
namespace {
#define GEN_PASS_CLASSES
#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc"
/// Name of the marker attribute that stores, as an array of integers, the loop
/// dimensions that have already been peeled for a given `LoopOp`.
constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";

/// Name of the unit marker attribute attached to the loop returned by peeling;
/// loops carrying it can be excluded from further peeling.
constexpr char kPartialIterationLabel[] = "__partial_iteration__";
/// Rewrite pattern that peels one dimension of a `LoopOp`: the loop is split so
/// that dimension `idx` only runs "full" iterations, while a second loop takes
/// care of the remaining partial iteration (if any).
///
/// Both loops are tagged with marker attributes so that the greedy driver does
/// not peel the same dimension of the same loop again.
struct TiledLoopPeelingPattern : public OpRewritePattern<LoopOp> {
  TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skip_partial)
      : OpRewritePattern<LoopOp>(ctx), idx(idx), skip_partial(skip_partial) {}

  /// Peels dimension `idx` of `loopOp`. Fails without modifying the IR when
  /// the dimension was already peeled, when `skip_partial` is set and the loop
  /// is marked as a partial iteration, when `idx` is out of range for the
  /// loop, or when the peeling helper itself fails.
  LogicalResult matchAndRewrite(LoopOp loopOp,
                                PatternRewriter &rewriter) const override {
    // Dimensions that earlier applications of this pattern recorded on the op.
    SmallVector<int64_t> donePeeling;
    if (Attribute marker = loopOp->getAttr(kPeeledLoopsLabel)) {
      for (Attribute dim : marker.cast<ArrayAttr>())
        donePeeling.push_back(dim.cast<IntegerAttr>().getInt());
      // Nothing to do if this dimension has been handled before.
      if (llvm::is_contained(donePeeling, idx)) return failure();
    }

    // Optionally leave loop nests with a partial iteration alone.
    if (skip_partial && loopOp->hasAttr(kPartialIterationLabel))
      return failure();

    // The requested dimension must exist on this loop.
    const auto numLoopDims =
        static_cast<int64_t>(loopOp.iterator_types().size());
    if (idx >= numLoopDims) return failure();

    // Peel loop and canonicalize.
    LoopOp partialLoop;
    if (failed(
            peelAndCanonicalizeGmlStLoop(rewriter, loopOp, idx, partialLoop)))
      return failure();

    // Record the peeled dimension on both loops so that neither of them is
    // rewritten for this dimension a second time.
    donePeeling.push_back(idx);
    ArrayAttr peeledLabel = rewriter.getI64ArrayAttr(donePeeling);
    rewriter.updateRootInPlace(
        loopOp, [&]() { loopOp->setAttr(kPeeledLoopsLabel, peeledLabel); });
    partialLoop->setAttr(kPeeledLoopsLabel, peeledLabel);
    partialLoop->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
    return success();
  }

  /// Index of loop to peel.
  int64_t idx;

  /// If set to true, do not peel LoopOps with a partial iteration.
  bool skip_partial;
};
/// Test pass that applies `TiledLoopPeelingPattern` greedily for every loop
/// dimension listed in the `dims` option and then strips the bookkeeping
/// marker attributes from all `LoopOp`s again.
class TestGmlStLoopPeelingPass
    : public TestGmlStLoopPeelingBase<TestGmlStLoopPeelingPass> {
  void runOnOperation() final {
    auto func = getOperation();
    MLIRContext *context = func.getContext();

    // Register one peeling pattern per requested dimension.
    RewritePatternSet peelingPatterns(context);
    for (unsigned dim : dims)
      peelingPatterns.add<TiledLoopPeelingPattern>(context, dim, skip_partial);

    // The driver's result is intentionally ignored.
    (void)(applyPatternsAndFoldGreedily(func, std::move(peelingPatterns)));

    // The markers only exist to steer the patterns; remove them from the
    // final IR.
    func.walk([](LoopOp loop) {
      loop->removeAttr(kPeeledLoopsLabel);
      loop->removeAttr(kPartialIterationLabel);
    });
  }
};
/// Rewrite pattern that tiles any op implementing `linalg::LinalgOp` via
/// `gml_st::tileLinalgOp`, restricted to the ops accepted by the given
/// transformation filter.
struct LinalgTilingPattern
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  LinalgTilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
                      linalg::LinalgTransformationFilter f,
                      PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
        filter(std::move(f)),
        options(std::move(options)) {}

  /// Tiles `op` with the stored options. Fails if the filter rejects `op` or
  /// if tiling itself fails; otherwise `op` is replaced by the tiled op's
  /// tensor results, or erased when there are none.
  LogicalResult matchAndRewrite(linalg::LinalgOp op,
                                PatternRewriter &rewriter) const override {
    // Only ops that pass the filter are candidates.
    if (failed(filter.checkAndNotify(rewriter, op))) return failure();

    FailureOr<linalg::TiledLinalgOp> tiled =
        gml_st::tileLinalgOp(rewriter, op, options);
    if (failed(tiled)) return failure();

    // Let the filter update its marker on the newly created tiled op.
    filter.replaceLinalgTransformationFilter(rewriter, tiled->op);

    // Substitute the original op: forward the tensor results if there are
    // any, otherwise there is nothing to replace it with and it is dropped.
    if (!tiled->tensorResults.empty())
      rewriter.replaceOp(op, tiled->tensorResults);
    else
      rewriter.eraseOp(op);
    return success();
  }

 private:
  /// Selects the ops to tile and marks the ops produced by tiling.
  linalg::LinalgTransformationFilter filter;

  /// Tiling configuration passed to `gml_st::tileLinalgOp`.
  linalg::LinalgTilingOptions options;
};
/// Test pass that tiles Linalg ops using `LinalgTilingPattern`, configured by
/// the `tile_sizes` and `distribution_types` pass options.
struct TestGmlStLoopTilingPass
    : public TestGmlStLoopTilingBase<TestGmlStLoopTilingPass> {
  TestGmlStLoopTilingPass() = default;

  /// Creates the pass with the options preset from the given tile sizes and
  /// distribution types.
  TestGmlStLoopTilingPass(ArrayRef<int64_t> tileSizes,
                          ArrayRef<StringRef> distributionTypes) {
    this->tile_sizes = tileSizes;

    // The option stores owning strings, so copy every StringRef.
    SmallVector<std::string, 2> ownedTypes;
    for (StringRef type : distributionTypes) ownedTypes.push_back(type.str());
    this->distribution_types = ownedTypes;
  }

  void runOnOperation() override {
    func::FuncOp func = getOperation();

    // Non-owning views of the distribution type strings held by the option.
    SmallVector<StringRef, 2> distTypeRefs;
    for (const std::string &type : distribution_types)
      distTypeRefs.push_back(StringRef(type));

    linalg::LinalgTilingOptions tilingOptions;
    tilingOptions.setTileSizes(tile_sizes);
    tilingOptions.setDistributionTypes(distTypeRefs);

    MLIRContext *context = func.getContext();
    RewritePatternSet tilingPatterns(context);

    // The filter matches ops regardless of existing markers and tags the ops
    // it produces with "tile".
    linalg::LinalgTransformationFilter tileFilter(
        ArrayRef<StringAttr>{}, StringAttr::get(context, "tile"));
    tilingPatterns.add<LinalgTilingPattern>(context, tilingOptions, tileFilter);

    // The driver's result is intentionally ignored.
    (void)applyPatternsAndFoldGreedily(func, std::move(tilingPatterns));

    // Remove the transformation marker from every Linalg op afterwards.
    func.walk([](linalg::LinalgOp linalgOp) {
      linalgOp->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
    });
  }
};
} // namespace
/// Creates a default-configured instance of the gml_st loop peeling test pass.
std::unique_ptr<OperationPass<func::FuncOp>> createTestGmlStLoopPeelingPass() {
  std::unique_ptr<OperationPass<func::FuncOp>> peelingPass =
      std::make_unique<TestGmlStLoopPeelingPass>();
  return peelingPass;
}
/// Creates a default-configured instance of the gml_st loop tiling test pass.
std::unique_ptr<OperationPass<func::FuncOp>> createTestGmlStLoopTilingPass() {
  std::unique_ptr<OperationPass<func::FuncOp>> tilingPass =
      std::make_unique<TestGmlStLoopTilingPass>();
  return tilingPass;
}
} // namespace gml_st
} // namespace mlir