blob: 1e9933993d0c4b3c0edeeca1a168802b401d9fcc [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace mhlo {
namespace {
struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
explicit ShapeReificationPattern(MLIRContext *context)
: OpRewritePattern<shape::ShapeOfOp>(context) {
// Recursively reify until we hit an op that doesn't support it.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
// Only reify shape computation if operand allows for it.
auto shapeOrigin = op.getArg().getDefiningOp<InferShapedTypeOpInterface>();
if (!shapeOrigin) return failure();
llvm::SmallVector<Value, 1> reifications;
if (failed(shapeOrigin.reifyReturnTypeShapes(
rewriter, shapeOrigin->getOperands(), reifications)))
return failure();
assert(reifications.size() == 1);
Value reifiedShape = reifications.front();
// Insert cast if needed.
if (reifiedShape.getType() != op.getType()) {
reifiedShape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
reifiedShape);
}
rewriter.replaceOp(op, reifiedShape);
return success();
}
};
template <typename OpTy>
struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Find all the shape operands, direct and indirect.
SmallVector<Value, 8> inlinedOperands;
for (Value direct : op->getOperands()) {
if (auto bcastOp = direct.getDefiningOp<shape::BroadcastOp>()) {
for (Value indirect : bcastOp->getOperands())
inlinedOperands.push_back(indirect);
} else {
inlinedOperands.push_back(direct);
}
}
// Only rewrite if it makes a difference.
if (inlinedOperands.size() == op.getNumOperands()) return failure();
// Inline shape operands.
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), inlinedOperands,
op->getAttrs());
return success();
}
};
LogicalResult moveUpIntoAssumingOpMatchAndRewrite(Operation *op,
PatternRewriter &rewriter) {
// Only implemented for single-result ops.
if (op->getNumResults() != 1) return failure();
// Find a preceding `assuming` op.
auto *theBlock = op->getBlock();
Operation *prev = op->getPrevNode();
while (prev != nullptr && !llvm::isa<shape::AssumingOp>(prev))
prev = prev->getPrevNode();
auto assumingOp = llvm::dyn_cast_or_null<shape::AssumingOp>(prev);
if (!assumingOp) return failure();
assert(assumingOp->getBlock() == theBlock && op->getBlock() == theBlock &&
"expect assuming op and root op to be in the same block");
// Make sure that all operands will be available after moving.
auto isAvailable = [&](Value v) {
Operation *def = v.getDefiningOp();
return def == nullptr || def->getBlock() != theBlock ||
!assumingOp->isBeforeInBlock(def);
};
if (!llvm::all_of(op->getOperands(), isAvailable)) return failure();
Block *body = assumingOp.getBody();
auto yieldOp = llvm::cast<shape::AssumingYieldOp>(body->getTerminator());
// Find the operands to use if the op was within the assuming region. We
// will later use their copies, as we copy the assuming op and its body.
SmallVector<Value, 8> newOperandsUnmapped =
llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
for (const auto &result : llvm::enumerate(assumingOp->getResults())) {
if (result.value() == v) return yieldOp->getOperand(result.index());
}
return v;
}));
// Insert the rewritten assuming op right before the old one.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(assumingOp);
auto newAssumingOp = rewriter.create<shape::AssumingOp>(
assumingOp.getLoc(), assumingOp.getWitness(),
[&](OpBuilder &b, Location) {
// Copy body.
BlockAndValueMapping mapping;
for (auto &nested : body->without_terminator())
b.clone(nested, mapping);
// Copy op into the new body and use the mapped operands.
for (auto it : llvm::zip(op->getOperands(), newOperandsUnmapped)) {
Value oldOperand, newOperandUnmapped;
std::tie(oldOperand, newOperandUnmapped) = it;
mapping.map(oldOperand, mapping.lookupOrDefault(newOperandUnmapped));
}
Operation *newOp = b.clone(*op, mapping);
// Yield the previous results and also the new ones.
auto mappedResults = llvm::to_vector<8>(llvm::map_range(
yieldOp.getOperands(),
[&](Value v) { return mapping.lookupOrDefault(v); }));
mappedResults.append(newOp->getResults().begin(),
newOp->getResults().end());
return mappedResults;
});
// Replace the assuming op and the root op with the corresponding result
// values.
ValueRange newAssumingOpResults = newAssumingOp->getResults();
rewriter.replaceOp(assumingOp, newAssumingOpResults.drop_back());
rewriter.replaceOp(op, newAssumingOpResults.back());
return success();
}
/// Move operation into a preceding assuming op. This allows to process
/// operations that depend on the assuming op's results. It will eventually
/// allow to make assuming regions' constraints independent from each other.
template <typename OpTy>
struct MoveUpIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
return moveUpIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
}
};
// Move elementwise operations into a preceding assuming op. This will
// eventually allow for more fusion opportunities.
struct MoveElementwiseOpsUpIntoAssumingOpPattern : public RewritePattern {
explicit MoveElementwiseOpsUpIntoAssumingOpPattern(MLIRContext *ctx)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Apply to all elementwise and broadcasting elementwise operations with no
// side effects.
if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>()) {
return failure();
}
if (!MemoryEffectOpInterface::hasNoEffect(op)) return failure();
return moveUpIntoAssumingOpMatchAndRewrite(op, rewriter);
}
};
// Move operation into an assuming region if all uses are within its body.
LogicalResult moveDownIntoAssumingOpMatchAndRewrite(Operation *op,
PatternRewriter &rewriter) {
auto users = op->getUsers();
auto it = users.begin();
auto end = users.end();
if (it == end) return failure();
// Find candidate assuming op.
auto assumingOp = (it++)->getParentOfType<shape::AssumingOp>();
if (!assumingOp || assumingOp->isProperAncestor(op)) return failure();
// Make sure all uses are within the unique assuming op's body.
while (it != end) {
auto hopefullySameAssumingOp = (it++)->getParentOfType<shape::AssumingOp>();
if (!hopefullySameAssumingOp || hopefullySameAssumingOp != assumingOp) {
return failure();
}
}
// Move op into the assuming region.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(assumingOp.getBody());
Operation *newOp = rewriter.clone(*op);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
// Move elementwise operations into succeeding assuming regions. This will
// eventually allow for more fusion opportunities.
struct MoveElementwiseOpsDownIntoAssumingOpPattern : public RewritePattern {
explicit MoveElementwiseOpsDownIntoAssumingOpPattern(MLIRContext *ctx)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Apply to all elementwise and broadcasting elementwise operations with no
// side effects.
if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>()) {
return failure();
}
if (!MemoryEffectOpInterface::hasNoEffect(op)) return failure();
return moveDownIntoAssumingOpMatchAndRewrite(op, rewriter);
}
};
/// Move operation out of assuming op. This is only valid for
/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
/// will eventually allow to make assuming regions' constraints independent from
/// each other.
template <typename OpTy>
struct MoveUpOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Must be inside of an assuming op.
auto assumingOp = op->template getParentOfType<shape::AssumingOp>();
if (!assumingOp) return failure();
// Operands must not be defined within the assuming op.
Block *body = assumingOp.getBody();
auto isAvailable = [&](Value v) {
Operation *def = v.getDefiningOp();
return def == nullptr || def->getBlock() != body;
};
if (!llvm::all_of(op->getOperands(), isAvailable)) return failure();
// Move op before the assuming region.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(assumingOp);
Operation *newOp = rewriter.clone(*op);
rewriter.replaceOp(op, newOp->getResults());
// If the assuming region yields none of the new op's results, these values
// are exclusively used in the assuming op's body. In these cases there is
// no need for further rewrites.
auto isNewOpResult = [newOp](Value v) {
return llvm::is_contained(newOp->getResults(), v);
};
auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
if (llvm::none_of(yield_op.getOperands(), isNewOpResult)) return success();
// If the assuming region yields any of the new op's results, these values
// can instead bypass the assuming region. There is no need to yield them
// explicitly as they are assumed to be independent. The assuming op is
// rewritten accordingly.
SmallVector<Value, 2> replacementValues;
auto newAssumingOp = rewriter.create<shape::AssumingOp>(
assumingOp.getLoc(), assumingOp.getWitness(),
[&](OpBuilder &b, Location) {
// Copy body.
BlockAndValueMapping mapping;
for (Operation &nested : body->without_terminator()) {
b.clone(nested, mapping);
}
// Collect new yield operands.
SmallVector<Value, 2> newYieldOperands;
for (Value result : yield_op.getOperands()) {
if (isNewOpResult(result)) {
replacementValues.push_back(result);
} else {
newYieldOperands.push_back(mapping.lookupOrDefault(result));
replacementValues.push_back(nullptr);
}
}
return newYieldOperands;
});
// Use the assuming op's results for the missing replacement values.
auto src = newAssumingOp.getResults().begin();
for (auto &dst : replacementValues) {
if (dst) continue;
dst = *src++;
}
rewriter.replaceOp(assumingOp, replacementValues);
return success();
}
};
/// Merge assuming regions if their constraints are independent from each other.
struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::AssumingOp op,
PatternRewriter &rewriter) const override {
// Merge assuming op with directly preceding one if both witnesses are
// availiable.
auto precedingOp =
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
if (!precedingOp) return failure();
if (op.getWitness().getDefiningOp() == precedingOp) return failure();
// Merge witnesses.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(precedingOp);
Value newWitness = rewriter.create<shape::AssumingAllOp>(
op.getWitness().getDefiningOp()->getLoc(),
ValueRange{precedingOp.getWitness(), op.getWitness()});
// Merge assuming ops.
Block *body_a = precedingOp.getBody();
Block *body_b = op.getBody();
auto newAssumingOp = rewriter.create<shape::AssumingOp>(
precedingOp.getLoc(), newWitness, [&](OpBuilder &b, Location) {
// Copy preceding op's body.
BlockAndValueMapping mapping;
for (auto &nested : body_a->without_terminator()) {
b.clone(nested, mapping);
}
// Map result values of preceding assuming op.
auto yieldOpA =
llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
for (auto pair :
llvm::zip(precedingOp->getResults(), yieldOpA.getOperands())) {
mapping.map(std::get<0>(pair),
mapping.lookupOrDefault(std::get<1>(pair)));
}
// Copy op's body.
for (auto &nested : body_b->without_terminator()) {
b.clone(nested, mapping);
}
// Collect merged assuming op's results.
SmallVector<Value, 4> mappedResults;
auto yieldOpB =
llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
for (Value v : yieldOpA.getOperands()) {
mappedResults.push_back(mapping.lookupOrDefault(v));
}
for (Value v : yieldOpB.getOperands()) {
mappedResults.push_back(mapping.lookupOrDefault(v));
}
return mappedResults;
});
// Replace the two assuming ops with the new corresponding results.
ValueRange newResults = newAssumingOp->getResults();
size_t splitAt = precedingOp->getNumResults();
rewriter.replaceOp(precedingOp, newResults.take_front(splitAt));
rewriter.replaceOp(op, newResults.drop_front(splitAt));
return success();
}
};
struct EliminateDuplicateCstrBroadcastableOps
: public OpRewritePattern<shape::CstrBroadcastableOp> {
using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
PatternRewriter &rewriter) const override {
// Search for previous occurence of the same constraint.
Operation *it = op->getPrevNode();
while (it != nullptr) {
if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
if (candidate.getShapes() == op.getShapes()) {
rewriter.replaceOp(op, candidate.getResult());
return success();
}
}
it = it->getPrevNode();
}
return failure();
}
};
struct MergeAssumingOpsPass
: public MergeAssumingOpsPassBase<MergeAssumingOpsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
mhlo::PopulateMergeAssumingOpsPatterns(ctx, &patterns);
GreedyRewriteConfig config;
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};
} // namespace
void PopulateMergeAssumingOpsPatterns(MLIRContext *context,
RewritePatternSet *patterns) {
// clang-format off
patterns->add<
EliminateDuplicateCstrBroadcastableOps,
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
MergeAssumingOpsPattern,
MoveElementwiseOpsDownIntoAssumingOpPattern,
MoveElementwiseOpsUpIntoAssumingOpPattern,
MoveUpIntoAssumingOpPattern<shape::AssumingAllOp>,
MoveUpIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
MoveUpIntoAssumingOpPattern<shape::ShapeOfOp>,
MoveUpOutOfAssumingOpPattern<shape::AssumingAllOp>,
MoveUpOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
MoveUpOutOfAssumingOpPattern<shape::ShapeOfOp>,
ShapeReificationPattern>(context);
// clang-format on
mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
context);
mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context);
shape::AssumingOp::getCanonicalizationPatterns(*patterns, context);
shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context);
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
}
std::unique_ptr<OperationPass<func::FuncOp>> createMergeAssumingOpsPass() {
return std::make_unique<MergeAssumingOpsPass>();
}
} // namespace mhlo
} // namespace mlir