blob: d9829366eed77e19b94df7522092d8855c467d8a [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace mhlo {
namespace {
/// Replaces `shape.shape_of(%v)` with the shape computation reified by the op
/// that defines `%v`, provided that op implements `InferShapedTypeOpInterface`.
/// A `tensor.cast` is inserted when the reified shape's type differs from the
/// `shape_of` result type.
struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
  explicit ShapeReificationPattern(MLIRContext *context)
      : OpRewritePattern<shape::ShapeOfOp>(context) {
    // Recursively reify until we hit an op that doesn't support it.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    // Only reify shape computation if operand allows for it.
    auto shape_origin = op.getArg().getDefiningOp<InferShapedTypeOpInterface>();
    if (!shape_origin) return failure();

    llvm::SmallVector<Value, 1> reifications;
    if (failed(shape_origin.reifyReturnTypeShapes(
            rewriter, shape_origin->getOperands(), reifications)))
      return failure();

    // `reifyReturnTypeShapes` produces one shape per result of
    // `shape_origin`. Select the shape of the result we are actually taking
    // the shape of: for multi-result ops this need not be result #0, and
    // always using `front()` would silently pick the wrong shape.
    unsigned result_number = op.getArg().cast<OpResult>().getResultNumber();
    assert(result_number < reifications.size() &&
           "expected one reified shape per result");
    Value reified_shape = reifications[result_number];

    // Insert cast if needed.
    if (reified_shape.getType() != op.getType()) {
      reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                      reified_shape);
    }

    rewriter.replaceOp(op, reified_shape);
    return success();
  }
};
/// Flattens `shape.broadcast` producers into the operand list of `op`: every
/// operand defined by a `shape.broadcast` is replaced by that broadcast's own
/// operands, and `op` is recreated with the flattened list. The rewrite is
/// only performed when the number of operands changes.
template <typename OpTy>
struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Gather operands, expanding broadcast producers in place.
    SmallVector<Value, 8> flattened;
    for (Value operand : op->getOperands()) {
      auto producer = operand.getDefiningOp<shape::BroadcastOp>();
      if (!producer) {
        flattened.push_back(operand);
        continue;
      }
      flattened.append(producer->operand_begin(), producer->operand_end());
    }

    // An unchanged operand count means there is nothing worth rewriting.
    const bool same_arity = flattened.size() == op->getNumOperands();
    if (same_arity) return failure();

    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), flattened,
                                      op->getAttrs());
    return success();
  }
};
// Returns true if `op` has no memory effects, or if it is a
// `shape.cstr_broadcastable` op (accepted regardless of its effects).
// NOTE(review): not referenced elsewhere in the visible file — confirm it is
// still needed.
bool IsMovable(Operation *op) {
  if (llvm::isa<shape::CstrBroadcastableOp>(op)) return true;
  return MemoryEffectOpInterface::hasNoEffect(op);
}
// Moves the single-result `op` into the closest preceding `shape.assuming` op
// in the same block. A new assuming op with the same witness is created right
// before the old one; its body is a clone of the old body followed by a clone
// of `op`, and it yields the old yield values plus `op`'s result. The old
// assuming op and `op` are then replaced with the new op's results.
//
// Fails if `op` does not have exactly one result, if no `shape.assuming` op
// precedes it in its block, or if any operand of `op` is defined in the block
// after that assuming op.
LogicalResult MoveUpIntoAssumingOpMatchAndRewrite(Operation *op,
                                                  PatternRewriter &rewriter) {
  // Only implemented for single-result ops.
  if (op->getNumResults() != 1) return failure();
  // Find a preceding `assuming` op.
  auto *the_block = op->getBlock();
  Operation *prev = op->getPrevNode();
  while (prev != nullptr && !llvm::isa<shape::AssumingOp>(prev))
    prev = prev->getPrevNode();
  auto assuming_op = llvm::dyn_cast_or_null<shape::AssumingOp>(prev);
  if (!assuming_op) return failure();
  assert(assuming_op->getBlock() == the_block && op->getBlock() == the_block &&
         "expect assuming op and root op to be in the same block");
  // Make sure that all operands will be available after moving.
  // A value is available if it has no defining op (block argument), is
  // defined outside `the_block`, or is defined at or before the assuming op.
  // Results of the assuming op itself count as available; they are remapped
  // to the corresponding yield operands below.
  auto is_available = [&](Value v) {
    Operation *def = v.getDefiningOp();
    return def == nullptr || def->getBlock() != the_block ||
           !assuming_op->isBeforeInBlock(def);
  };
  if (!llvm::all_of(op->getOperands(), is_available)) return failure();
  Block *body = assuming_op.getBody();
  auto yield_op = llvm::cast<shape::AssumingYieldOp>(body->getTerminator());
  // Find the operands to use if the op was within the assuming region. We
  // will later use their copies, as we copy the assuming op and its body.
  // Each operand that is a result of the assuming op is translated to the
  // value yielded for that result; other operands are kept as they are.
  SmallVector<Value, 8> new_operands_unmapped =
      llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
        for (const auto &result : llvm::enumerate(assuming_op->getResults())) {
          if (result.value() == v) return yield_op->getOperand(result.index());
        }
        return v;
      }));
  // Insert the rewritten assuming op right before the old one.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(assuming_op);
  auto new_assuming_op = rewriter.create<shape::AssumingOp>(
      assuming_op.getLoc(), assuming_op.getWitness(),
      [&](OpBuilder &b, Location) {
        // Copy body.
        BlockAndValueMapping mapping;
        for (auto &nested : body->without_terminator())
          b.clone(nested, mapping);
        // Copy op into the new body and use the mapped operands.
        // `lookupOrDefault` resolves values that were cloned above to their
        // copies and leaves values from outside the body unchanged.
        for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) {
          Value old_operand, new_operand_unmapped;
          std::tie(old_operand, new_operand_unmapped) = it;
          mapping.map(old_operand,
                      mapping.lookupOrDefault(new_operand_unmapped));
        }
        Operation *new_op = b.clone(*op, mapping);
        // Yield the previous results and also the new ones.
        auto mapped_results = llvm::to_vector<8>(llvm::map_range(
            yield_op.getOperands(),
            [&](Value v) { return mapping.lookupOrDefault(v); }));
        mapped_results.append(new_op->getResults().begin(),
                              new_op->getResults().end());
        return mapped_results;
      });
  // Replace the assuming op and the root op with the corresponding result
  // values.
  // All results but the last stand in for the old assuming op; the last one
  // is the moved op's single result.
  ValueRange new_assuming_op_results = new_assuming_op->getResults();
  rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
  rewriter.replaceOp(op, new_assuming_op_results.back());
  return success();
}
/// Pulls an `OpTy` op into the closest preceding `shape.assuming` op of its
/// block, so that operations consuming the assuming op's results can be
/// processed together with its body. This is a step towards making the
/// constraints of different assuming regions independent of each other.
template <typename OpTy>
struct MoveUpIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Delegate to the type-erased implementation.
    Operation *generic_op = op;
    return MoveUpIntoAssumingOpMatchAndRewrite(generic_op, rewriter);
  }
};
// Pulls side-effect-free elementwise and broadcasting-elementwise ops into a
// preceding assuming op, with the aim of enabling more fusion later on.
struct MoveElementwiseOpsUpIntoAssumingOpPattern : public RewritePattern {
  explicit MoveElementwiseOpsUpIntoAssumingOpPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    const bool is_elementwise =
        op->hasTrait<mlir::OpTrait::Elementwise>() ||
        op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>();
    // Only pure (effect-free) elementwise ops qualify.
    if (!is_elementwise || !MemoryEffectOpInterface::hasNoEffect(op))
      return failure();
    return MoveUpIntoAssumingOpMatchAndRewrite(op, rewriter);
  }
};
// Sinks `op` into an assuming region when every user of `op` has that same
// `shape.assuming` op as its closest enclosing assuming ancestor. The op is
// cloned to the start of the region's body and the original is replaced.
LogicalResult MoveDownIntoAssumingOpMatchAndRewrite(Operation *op,
                                                    PatternRewriter &rewriter) {
  // An unused op has no region to sink towards.
  if (op->use_empty()) return failure();

  // The first user determines the candidate assuming op. Bail out if there is
  // none, or if `op` already lives inside it.
  Operation *first_user = *op->user_begin();
  auto target = first_user->getParentOfType<shape::AssumingOp>();
  if (!target || target->isProperAncestor(op)) return failure();

  // Every user has to share that candidate as its innermost assuming parent.
  const bool all_users_inside =
      llvm::all_of(op->getUsers(), [&](Operation *user) {
        return user->getParentOfType<shape::AssumingOp>() == target;
      });
  if (!all_users_inside) return failure();

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(target.getBody());
  Operation *sunk = rewriter.clone(*op);
  rewriter.replaceOp(op, sunk->getResults());
  return success();
}
// Sinks side-effect-free elementwise and broadcasting-elementwise ops into the
// assuming region that contains all of their users, with the aim of enabling
// more fusion later on.
struct MoveElementwiseOpsDownIntoAssumingOpPattern : public RewritePattern {
  explicit MoveElementwiseOpsDownIntoAssumingOpPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    const bool elementwise_like =
        op->hasTrait<mlir::OpTrait::Elementwise>() ||
        op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>();
    const bool eligible =
        elementwise_like && MemoryEffectOpInterface::hasNoEffect(op);
    if (!eligible) return failure();
    return MoveDownIntoAssumingOpMatchAndRewrite(op, rewriter);
  }
};
/// Move operation out of assuming op. This is only valid for
/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
/// will eventually allow to make assuming regions' constraints independent from
/// each other.
///
/// The op is cloned right before its closest enclosing `shape.assuming` op. If
/// that assuming op yields any of the op's results, the assuming op is rebuilt
/// without yielding them, and those results' users are redirected to the
/// hoisted op directly.
template <typename OpTy>
struct MoveUpOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Must be inside of an assuming op.
    auto assuming_op = op->template getParentOfType<shape::AssumingOp>();
    if (!assuming_op) return failure();

    // Operands must not be defined anywhere within the assuming op. Checking
    // only the body block is not sufficient: `op` may be nested in a region
    // inside the body (e.g. a loop), and values defined there, including
    // block arguments, would not dominate the hoisted clone. For ops directly
    // in the body, this check is equivalent to rejecting body-defined values.
    Block *body = assuming_op.getBody();
    Region *assuming_region = body->getParent();
    auto is_available = [&](Value v) {
      return !assuming_region->isAncestor(v.getParentRegion());
    };
    if (!llvm::all_of(op->getOperands(), is_available)) return failure();

    // Move op before the assuming region.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(assuming_op);
    Operation *new_op = rewriter.clone(*op);
    rewriter.replaceOp(op, new_op->getResults());

    // If the assuming region yields none of the new op's results, these values
    // are exclusively used in the assuming op's body. In these cases there is
    // no need for further rewrites.
    auto is_new_op_result = [&](Value v) {
      return llvm::is_contained(new_op->getResults(), v);
    };
    auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
    if (llvm::none_of(yield_op.getOperands(), is_new_op_result))
      return success();

    // If the assuming region yields any of the new op's results, these values
    // can instead bypass the assuming region. There is no need to yield them
    // explicitly as they are assumed to be independent. The assuming op is
    // rewritten accordingly.
    SmallVector<Value, 2> replacement_values;
    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
        assuming_op.getLoc(), assuming_op.getWitness(),
        [&](OpBuilder &b, Location) {
          // Copy body.
          BlockAndValueMapping mapping;
          for (Operation &nested : body->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect new yield operands. Hoisted results are recorded directly
          // as replacements; all other slots are filled from the new assuming
          // op's results below (marked with a null placeholder here).
          SmallVector<Value, 2> new_yield_operands;
          for (Value result : yield_op.getOperands()) {
            if (is_new_op_result(result)) {
              replacement_values.push_back(result);
            } else {
              new_yield_operands.push_back(mapping.lookupOrDefault(result));
              replacement_values.push_back(nullptr);
            }
          }
          return new_yield_operands;
        });

    // Use the assuming op's results for the missing replacement values.
    auto src = new_assuming_op.getResults().begin();
    for (auto &dst : replacement_values) {
      if (dst) continue;
      dst = *src++;
    }

    rewriter.replaceOp(assuming_op, replacement_values);
    return success();
  }
};
/// Merge assuming regions if their constraints are independent from each other.
///
/// Applies when `op` directly follows another `shape.assuming` op and `op`'s
/// witness is not a result of that preceding op. The two witnesses are
/// combined with a `shape.assuming_all`, and both bodies are concatenated into
/// a single new assuming op. The new op's results are the preceding op's
/// results followed by `op`'s results.
struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
  using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::AssumingOp op,
                                PatternRewriter &rewriter) const override {
    // Merge assuming op with directly preceding one if both witnesses are
    // available.
    auto preceding_op =
        llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
    if (!preceding_op) return failure();
    // A witness produced by the preceding op is not available before it.
    if (op.getWitness().getDefiningOp() == preceding_op) return failure();

    // Merge witnesses. Take the location from the witness value itself: the
    // witness may be a block argument with no defining op, in which case
    // dereferencing `getDefiningOp()` would be a null-pointer access.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(preceding_op);
    Value new_witness = rewriter.create<shape::AssumingAllOp>(
        op.getWitness().getLoc(),
        ValueRange{preceding_op.getWitness(), op.getWitness()});

    // Merge assuming ops.
    Block *body_a = preceding_op.getBody();
    Block *body_b = op.getBody();
    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
        preceding_op.getLoc(), new_witness, [&](OpBuilder &b, Location) {
          // Copy preceding op's body.
          BlockAndValueMapping mapping;
          for (auto &nested : body_a->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Map result values of preceding assuming op, so that uses of them
          // in `op`'s body refer to the cloned yield values instead.
          auto yield_op_a =
              llvm::cast<shape::AssumingYieldOp>(body_a->getTerminator());
          for (auto pair : llvm::zip(preceding_op->getResults(),
                                     yield_op_a.getOperands())) {
            mapping.map(std::get<0>(pair),
                        mapping.lookupOrDefault(std::get<1>(pair)));
          }

          // Copy op's body.
          for (auto &nested : body_b->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect merged assuming op's results.
          SmallVector<Value, 4> mapped_results;
          auto yield_op_b =
              llvm::cast<shape::AssumingYieldOp>(body_b->getTerminator());
          for (Value v : yield_op_a.getOperands()) {
            mapped_results.push_back(mapping.lookupOrDefault(v));
          }
          for (Value v : yield_op_b.getOperands()) {
            mapped_results.push_back(mapping.lookupOrDefault(v));
          }
          return mapped_results;
        });

    // Replace the two assuming ops with the new corresponding results.
    ValueRange new_results = new_assuming_op->getResults();
    size_t split_at = preceding_op->getNumResults();
    rewriter.replaceOp(preceding_op, new_results.take_front(split_at));
    rewriter.replaceOp(op, new_results.drop_front(split_at));
    return success();
  }
};
/// Replaces a `shape.cstr_broadcastable` op with an earlier op in the same
/// block that has exactly the same shape operands, in the same order.
struct EliminateDuplicateCstrBroadcastableOps
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
  using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    // Scan backwards; the nearest matching constraint wins.
    for (Operation *prev = op->getPrevNode(); prev != nullptr;
         prev = prev->getPrevNode()) {
      auto earlier = llvm::dyn_cast<shape::CstrBroadcastableOp>(prev);
      if (earlier && earlier.getShapes() == op.getShapes()) {
        rewriter.replaceOp(op, earlier.getResult());
        return success();
      }
    }
    return failure();
  }
};
/// Pass that greedily applies the merge-assuming-ops pattern set to the
/// current operation and signals failure if the driver reports failure.
struct MergeAssumingOpsPass
    : public MergeAssumingOpsPassBase<MergeAssumingOpsPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    mhlo::PopulateMergeAssumingOpsPatterns(&getContext(), &patterns);

    // Lift the iteration cap so the driver keeps going until no pattern
    // applies anymore.
    GreedyRewriteConfig config;
    config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;

    LogicalResult status = applyPatternsAndFoldGreedily(
        getOperation(), std::move(patterns), config);
    if (failed(status)) signalPassFailure();
  }
};
} // namespace
// Adds the assuming-op moving/merging patterns defined in this file to
// `*patterns`, followed by the canonicalization patterns of related mhlo,
// shape and tensor ops.
void PopulateMergeAssumingOpsPatterns(MLIRContext *context,
                                      RewritePatternSet *patterns) {
  RewritePatternSet &set = *patterns;
  // clang-format off
  set.add<
      EliminateDuplicateCstrBroadcastableOps,
      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
      MergeAssumingOpsPattern,
      MoveElementwiseOpsDownIntoAssumingOpPattern,
      MoveElementwiseOpsUpIntoAssumingOpPattern,
      MoveUpIntoAssumingOpPattern<shape::AssumingAllOp>,
      MoveUpIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveUpIntoAssumingOpPattern<shape::ShapeOfOp>,
      MoveUpOutOfAssumingOpPattern<shape::AssumingAllOp>,
      MoveUpOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveUpOutOfAssumingOpPattern<shape::ShapeOfOp>,
      ShapeReificationPattern>(context);
  // clang-format on

  // Canonicalizations, registered in a fixed order after the patterns above.
  mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(set, context);
  mhlo::DynamicReshapeOp::getCanonicalizationPatterns(set, context);
  shape::AssumingAllOp::getCanonicalizationPatterns(set, context);
  shape::AssumingOp::getCanonicalizationPatterns(set, context);
  shape::BroadcastOp::getCanonicalizationPatterns(set, context);
  shape::CstrBroadcastableOp::getCanonicalizationPatterns(set, context);
  tensor::CastOp::getCanonicalizationPatterns(set, context);
}
// Factory for the merge-assuming-ops pass, returned as a func-level pass.
std::unique_ptr<OperationPass<func::FuncOp>> createMergeAssumingOpsPass() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<MergeAssumingOpsPass>();
  return pass;
}
} // namespace mhlo
} // namespace mlir