blob: 0119789a3a1ac3ccb3621210b06e0d988b500eed [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <functional>
#include <memory>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace mhlo {
namespace {
// Describes how a block argument contributes a shape to a broadcastability
// constraint. The enumerator order is significant: it is the primary sort key
// when ordering constraint operands.
enum class CstrBroadcastableOperandKind {
  // The block argument itself is used as the shape.
  kValue,
  // The shape is the `shape.shape_of` of the block argument.
  kShapeOfValue,
};
// One operand of a broadcastability constraint, expressed in terms of a block
// argument: either the argument is the shape itself (`kValue`) or the shape is
// the `shape.shape_of` of the argument (`kShapeOfValue`).
struct CstrBroadcastableOperand {
  // Builds an operand that uses `barg` directly as a shape.
  static CstrBroadcastableOperand valueOf(BlockArgument barg) {
    CstrBroadcastableOperand operand{CstrBroadcastableOperandKind::kValue,
                                     barg};
    return operand;
  }

  // Builds an operand that uses the shape of `barg`.
  static CstrBroadcastableOperand shapeOf(BlockArgument barg) {
    CstrBroadcastableOperand operand{
        CstrBroadcastableOperandKind::kShapeOfValue, barg};
    return operand;
  }

  // An arbitrary but well-defined order: by kind first, then by argument
  // number. Argument numbers are only comparable for arguments of the same
  // block.
  inline bool operator<(const CstrBroadcastableOperand &rhs) const {
    if (kind == rhs.kind)
      return value.getArgNumber() < rhs.value.getArgNumber();
    return kind < rhs.kind;
  }
  inline bool operator>(const CstrBroadcastableOperand &rhs) const {
    return rhs.operator<(*this);
  }
  inline bool operator<=(const CstrBroadcastableOperand &rhs) const {
    return !(rhs < *this);
  }
  inline bool operator>=(const CstrBroadcastableOperand &rhs) const {
    return !operator<(rhs);
  }

  // Two operands are equal iff they refer to the same argument in the same way.
  inline bool operator==(const CstrBroadcastableOperand &rhs) const {
    return value == rhs.value && kind == rhs.kind;
  }
  inline bool operator!=(const CstrBroadcastableOperand &rhs) const {
    return !operator==(rhs);
  }

  CstrBroadcastableOperandKind kind;
  BlockArgument value;
};
// A `shape.cstr_broadcastable` constraint that is yet to be materialized,
// described by its source location and the operands it constrains.
struct CstrBroadcastableIntent {
  explicit CstrBroadcastableIntent(Location loc) : loc(loc) {}

  // A well-defined order that puts weaker constraints (fewer operands) first;
  // intents of equal size are ordered lexicographically by their operands.
  inline bool operator<(const CstrBroadcastableIntent &rhs) const {
    auto lhsSize = operands.size();
    auto rhsSize = rhs.operands.size();
    return lhsSize == rhsSize ? operands < rhs.operands : lhsSize < rhsSize;
  }
  inline bool operator>(const CstrBroadcastableIntent &rhs) const {
    return rhs.operator<(*this);
  }
  inline bool operator<=(const CstrBroadcastableIntent &rhs) const {
    return !(rhs < *this);
  }
  inline bool operator>=(const CstrBroadcastableIntent &rhs) const {
    return !operator<(rhs);
  }

  // Equality only considers the operands; the location is ignored.
  inline bool operator==(const CstrBroadcastableIntent &rhs) const {
    return operands == rhs.operands;
  }
  inline bool operator!=(const CstrBroadcastableIntent &rhs) const {
    return !operator==(rhs);
  }

  Location loc;
  SmallVector<CstrBroadcastableOperand> operands;
};
// Canonicalizes the broadcastability constraints in place.
//
// 1. The operands of every constraint are sorted and deduplicated.
// 2. The constraints are sorted so that the strongest ones (most operands)
//    come first.
// 3. Every constraint whose operands are a subset of a stronger constraint's
//    operands is removed. Broadcastability of a set of shapes implies
//    broadcastability of any subset of it.
void canonicalizeBroadcastabilityCstrs(
    SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
  // Sort inner constraint arguments and eliminate duplicates.
  for (auto &it : broadcastabilityCstrs) {
    llvm::sort(it.operands);
    auto *newEnd =
        llvm::unique(it.operands, [](auto a, auto b) { return a == b; });
    it.operands.erase(newEnd, it.operands.end());
  }

  // Sort broadcastability constraints so that the strongest are at the front.
  llvm::sort(broadcastabilityCstrs, std::greater<>());

  // Remove broadcastability constraints if they are implied by stronger
  // constraints. The vector shrinks during the loop, so the size is re-read in
  // every iteration. Only elements after `i` are moved or erased, so the
  // reference to the stronger constraint stays valid. Binding by reference
  // avoids copying operand vectors on every predicate call.
  for (size_t i = 0; i < broadcastabilityCstrs.size(); ++i) {
    const CstrBroadcastableIntent &strongCstr = broadcastabilityCstrs[i];
    auto *newEnd = std::remove_if(
        broadcastabilityCstrs.begin() + i + 1, broadcastabilityCstrs.end(),
        [&strongCstr](const CstrBroadcastableIntent &weakerCstr) {
          assert(weakerCstr.operands.size() <= strongCstr.operands.size() &&
                 "only look at possibly weaker broadcastability constraints");
          // Both operand lists were sorted above with the same `operator<`,
          // as `std::includes` requires.
          return std::includes(
              strongCstr.operands.begin(), strongCstr.operands.end(),
              weakerCstr.operands.begin(), weakerCstr.operands.end());
        });
    broadcastabilityCstrs.erase(newEnd, broadcastabilityCstrs.end());
  }
}
// Sorts `bargs` by argument number and drops repeated entries, in place.
void eliminateDuplicateBlockArguments(SmallVector<BlockArgument> &bargs) {
  auto byArgNumber = [](BlockArgument lhs, BlockArgument rhs) {
    return lhs.getArgNumber() < rhs.getArgNumber();
  };
  llvm::sort(bargs, byArgNumber);
  // Duplicates are adjacent after sorting; std::unique compares with `==`.
  bargs.erase(std::unique(bargs.begin(), bargs.end()), bargs.end());
}
// Inlines every `shape.assuming` op nested anywhere within `theBlock`.
//
// For each assuming op:
//   - the ops of its body, including the `shape.assuming_yield` terminator,
//     are moved in front of it,
//   - uses of its results are replaced with the yielded values, and
//   - the yield and the assuming op are erased.
// The witness operand is simply dropped, so callers must make sure the
// constraints are enforced elsewhere.
// NOTE(review): erasing `aop` from inside the walk callback presumably relies
// on MLIR's default post-order walk. That walk tolerates erasing the visited op
// and visits nested assuming ops before their parents -- confirm.
void inlineAssumingRegions(Block *theBlock) {
theBlock->walk([](shape::AssumingOp aop) {
Block *body = aop.getBody();
auto yop = llvm::cast<shape::AssumingYieldOp>(body->getTerminator());
// Move the whole body, terminator included, right before `aop`.
aop->getBlock()->getOperations().splice(aop->getIterator(),
body->getOperations());
// Forward the yielded values to all users of the region's results.
aop.replaceAllUsesWith(yop.getOperands());
// The yield has no results, and `aop` has no uses left, so both can go.
yop.erase();
aop.erase();
});
}
// Materializes the fused constraint witness at the builder's current insertion
// point and returns it.
//
// - Witnesses in `argumentCstrs` are block arguments and are passed through
//   unchanged.
// - Every intent in `broadcastabilityCstrs` becomes a
//   `shape.cstr_broadcastable` op.
// - Operands of kind `kShapeOfValue` are turned into `shape.shape_of` ops,
//   created at most once per argument. Operands of kind `kValue` are used as
//   shapes directly.
// - Intents with fewer than two operands are trivially satisfied and skipped.
//   This can happen after duplicate elimination, e.g. for
//   `cstr_broadcastable %s, %s`, and `shape.cstr_broadcastable` requires at
//   least two shapes.
//
// Returns the single witness directly, or a `shape.assuming_all` of all
// witnesses. If no witness remains, a passing `shape.const_witness` is
// returned, since `shape.assuming_all` requires at least one operand.
Value materializeFusedConstraints(
    Location loc, OpBuilder &builder, SmallVector<BlockArgument> &argumentCstrs,
    SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
  // Ensure to materialize shape_of only once.
  DenseMap<Value, Value> shapeOfMaterializations;
  auto getShapeOfMaterialization = [&](Value arg) -> Value {
    auto it = shapeOfMaterializations.find(arg);
    if (it != shapeOfMaterializations.end()) return it->second;
    Value shapeOf = builder.create<shape::ShapeOfOp>(loc, arg).getResult();
    shapeOfMaterializations[arg] = shapeOf;
    return shapeOf;
  };
  // Maps a constraint operand to the shape value it stands for.
  auto getShape = [&](const CstrBroadcastableOperand &operand) -> Value {
    if (operand.kind == CstrBroadcastableOperandKind::kShapeOfValue)
      return getShapeOfMaterialization(operand.value);
    assert(operand.kind == CstrBroadcastableOperandKind::kValue);
    return operand.value;
  };

  SmallVector<Value> witnesses;
  witnesses.reserve(argumentCstrs.size() + broadcastabilityCstrs.size());

  // Carry over the argument witnesses.
  for (BlockArgument it : argumentCstrs) witnesses.push_back(it);

  // Materialize broadcastability constraints.
  for (const CstrBroadcastableIntent &it : broadcastabilityCstrs) {
    // A single shape is always broadcastable; nothing to check.
    if (it.operands.size() < 2) continue;
    auto shapes = llvm::to_vector<8>(llvm::map_range(it.operands, getShape));
    auto cstr = builder.create<shape::CstrBroadcastableOp>(it.loc, shapes);
    witnesses.push_back(cstr);
  }

  if (witnesses.empty()) {
    return builder.create<shape::ConstWitnessOp>(
        loc, builder.getType<shape::WitnessType>(), builder.getBoolAttr(true));
  }
  if (witnesses.size() == 1) return witnesses.front();
  return builder.create<shape::AssumingAllOp>(loc, witnesses);
}
// Replaces all assuming regions of `theBlock` with one assuming region. That
// region is guarded by the fused constraints and wraps the whole block body.
//
// `loc` is used for all newly created ops. `toBeErased` holds ops that become
// redundant once the assuming regions are inlined. `argumentCstrs` and
// `broadcastabilityCstrs` describe the fused constraints to materialize.
//
// Afterwards the block consists of three parts:
//   - the materialized constraints,
//   - one `shape.assuming` op holding all other ops, and
//   - the original terminator, which now uses the results of that assuming op.
void materializeBlockGlobalConstraintFusion(
Location loc, OpBuilder &builder, Block *theBlock,
llvm::SmallSetVector<Operation *, 16> &toBeErased,
SmallVector<BlockArgument> &argumentCstrs,
SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
// Eliminate the old assuming regions and inline their ops into the block.
inlineAssumingRegions(theBlock);
// Delete ops that are known to have become redundant by inlining of assuming
// regions. They are erased in insertion order.
// NOTE(review): this presumably relies on users being flagged before the ops
// they use, so that no op is erased while it still has uses -- confirm.
for (auto *it : toBeErased) it->erase();
// Materialize fused constraints at the beginning of the block.
builder.setInsertionPointToStart(theBlock);
Value fusedCstr = materializeFusedConstraints(loc, builder, argumentCstrs,
broadcastabilityCstrs);
// Create fused assuming region with empty body. Its result types mirror the
// operand types of the block terminator.
Operation *theBlockTerminator = theBlock->getTerminator();
auto fusedAop = builder.create<shape::AssumingOp>(
loc, theBlockTerminator->getOperandTypes(), fusedCstr);
// The region takes ownership of the newly allocated block.
auto *fusedAopBody = new Block;
fusedAop.getDoRegion().getBlocks().push_back(fusedAopBody);
// Splice all the original block's operations into the fused assuming region's
// body (except for the block terminator).
// The builder's insertion point still refers to the block's original first op,
// because the ops created above were inserted before it. Hence the
// materialized constraints and `fusedAop` itself are not in the spliced range.
auto &dstBlocks = fusedAopBody->getOperations();
dstBlocks.splice(dstBlocks.begin(), theBlock->getOperations(),
builder.getInsertionPoint(),
theBlockTerminator->getIterator());
// Yield results from the assuming region and pass them on to the original
// block terminator.
builder.setInsertionPointToEnd(fusedAopBody);
builder.create<shape::AssumingYieldOp>(loc,
theBlockTerminator->getOperands());
theBlockTerminator->setOperands(fusedAop.getResults());
}
// Returns true if `use` is a real use that keeps the used value alive after
// constraint fusion in `theBlock`. Returns false if the use will disappear.
//
// A use is not considered remaining if
//   - its owner is flagged in `considerDead`,
//   - its owner is a `shape.assuming` op within `theBlock`, at any nesting
//     depth, as all such ops are inlined during fusion, or
//   - its owner is a `shape.assuming_yield` op and no use of the corresponding
//     result of the enclosing `shape.assuming` op is remaining.
// Any other use is conservatively considered remaining.
bool isRemainingUse(OpOperand &use, Block *theBlock,
                    llvm::SmallSetVector<Operation *, 16> &considerDead) {
  Operation *op = use.getOwner();

  // Not a real use if user is considered dead.
  if (considerDead.count(op)) return false;

  // Assuming regions within the regarded block are not a real use as they will
  // be inlined. The inlining walks the block recursively, so this includes
  // nested assuming regions. Assuming regions elsewhere are real uses.
  if (auto aop = llvm::dyn_cast<shape::AssumingOp>(op)) {
    Operation *ancestor = aop.getOperation();
    while (ancestor != nullptr && ancestor->getBlock() != theBlock)
      ancestor = ancestor->getParentOp();
    bool willBeInlined = ancestor != nullptr;
    return !willBeInlined;
  }

  // Look through assuming regions' yield ops.
  if (auto yop = llvm::dyn_cast<shape::AssumingYieldOp>(op)) {
    auto aop = yop->getParentOfType<shape::AssumingOp>();
    auto outerResult = aop.getResults()[use.getOperandNumber()];
    return llvm::all_of(outerResult.getUses(), [&](auto &outerUse) {
      return isRemainingUse(outerUse, theBlock, considerDead);
    });
  }

  // Otherwise, consider it a real use.
  return true;
}
// Adds `op` to `toBeErased` unless at least one of its uses is still a real
// use after fusion, as decided by `isRemainingUse`.
void tryFlagForErase(Block *theBlock, Operation *op,
                     llvm::SmallSetVector<Operation *, 16> &toBeErased) {
  bool hasRemainingUse = llvm::any_of(op->getUses(), [&](OpOperand &use) {
    return isRemainingUse(use, theBlock, toBeErased);
  });
  if (hasRemainingUse) return;
  toBeErased.insert(op);
}
// Returns true if `op` or one of its ancestor ops is placed directly in
// `theBlock`.
bool isWithinBlock(Operation *op, Block *theBlock) {
  for (Operation *current = op; current != nullptr;
       current = current->getParentOp()) {
    if (current->getBlock() == theBlock) return true;
  }
  return false;
}
// Traces the shape operands of `cstrBcastable` back to arguments of `theBlock`
// and appends the resulting operands to `transitiveBcastableCstrOperands`.
//
// Shapes are followed through:
//   - arguments of `theBlock`, recorded as `kValue`;
//   - `shape.shape_of` of an argument of `theBlock`, recorded as
//     `kShapeOfValue`;
//   - `shape.broadcast` ops, whose operands are inlined into the constraint
//     because their broadcastability is implied;
//   - results of `shape.assuming` ops, by following the yielded value.
// Ops that become redundant are flagged in `toBeErased`.
//
// Fails if any shape cannot be traced back this way.
LogicalResult analyzeBroadcastableConstraint(
    shape::CstrBroadcastableOp cstrBcastable, Block *theBlock,
    llvm::SmallSetVector<Operation *, 16> &toBeErased,
    SmallVector<CstrBroadcastableOperand> &transitiveBcastableCstrOperands) {
  SmallVector<Value> worklist = cstrBcastable.getShapes();
  while (!worklist.empty()) {
    Value shape = worklist.pop_back_val();
    Operation *def = shape.getDefiningOp();

    // For shapes without a definition, expect them to be an argument of the
    // regarded block.
    if (def == nullptr) {
      auto barg = shape.dyn_cast<BlockArgument>();
      if (!barg || barg.getParentBlock() != theBlock) return failure();
      transitiveBcastableCstrOperands.push_back(
          CstrBroadcastableOperand::valueOf(barg));
      continue;
    }

    // For shape_of ops, expect them to wrap an argument of the regarded block.
    // The shape reification pass helps achieve this, which should be run before
    // this pass.
    if (auto sof = llvm::dyn_cast<shape::ShapeOfOp>(def)) {
      if (!isWithinBlock(sof, theBlock)) return failure();
      tryFlagForErase(theBlock, def, toBeErased);
      // The fused shape_of is materialized at the start of the regarded block.
      // The wrapped argument must therefore belong to that very block and not,
      // e.g., to a nested region. This also keeps the operand order, which
      // compares argument numbers, well-defined.
      auto barg = sof.getArg().dyn_cast<BlockArgument>();
      if (!barg || barg.getParentBlock() != theBlock) return failure();
      transitiveBcastableCstrOperands.push_back(
          CstrBroadcastableOperand::shapeOf(barg));
      continue;
    }

    // For broadcast ops, broadcastability of the operands is an implicit
    // requirement. We can inline the operands.
    if (auto bcast = llvm::dyn_cast<shape::BroadcastOp>(def)) {
      if (!isWithinBlock(bcast, theBlock)) return failure();
      tryFlagForErase(theBlock, def, toBeErased);
      auto bcastShapes = bcast.getShapes();
      worklist.append(bcastShapes.begin(), bcastShapes.end());
      continue;
    }

    // Look into assuming ops to proceed.
    if (auto aop = llvm::dyn_cast<shape::AssumingOp>(def)) {
      if (!isWithinBlock(aop, theBlock)) return failure();
      auto yieldOp =
          llvm::cast<shape::AssumingYieldOp>(aop.getBody()->getTerminator());
      size_t i = llvm::find(aop.getResults(), shape).getIndex();
      Value innerShape = yieldOp.getOperand(i);
      worklist.push_back(innerShape);
      continue;
    }

    // Otherwise, bail.
    return failure();
  }
  return success();
}
// Collects all constraints reachable from the witnesses of the `shape.assuming`
// ops within `theBlock`.
//
// - Witnesses that are arguments of `theBlock` go to `argumentCstrs`.
// - Every `shape.cstr_broadcastable` found is analyzed and recorded as an
//   intent in `broadcastabilityCstrs`.
// - `shape.assuming_all` conjunctions are traversed, and values yielded from
//   `shape.assuming` ops are followed into their regions.
// - Ops that become redundant are flagged in `toBeErased`.
//
// Fails if any witness stems from an unsupported op or from outside
// `theBlock`.
LogicalResult analyzeBlockGlobalConstraints(
    Block *theBlock, llvm::SmallSetVector<Operation *, 16> &toBeErased,
    SmallVector<BlockArgument> &argumentCstrs,
    SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
  // Seed the worklist with the witnesses of all assuming regions.
  SmallVector<Value> pending;
  theBlock->walk([&](shape::AssumingOp assumingOp) {
    pending.push_back(assumingOp.getWitness());
  });

  while (!pending.empty()) {
    Value witness = pending.pop_back_val();
    Operation *defOp = witness.getDefiningOp();

    // A witness without a defining op must be an argument of this block.
    if (!defOp) {
      auto blockArg = witness.dyn_cast<BlockArgument>();
      if (blockArg && blockArg.getParentBlock() == theBlock) {
        argumentCstrs.push_back(blockArg);
        continue;
      }
      return failure();
    }

    // Conjunction: each of its operands is a constraint in its own right.
    if (auto allOp = llvm::dyn_cast<shape::AssumingAllOp>(defOp)) {
      if (!isWithinBlock(allOp, theBlock)) return failure();
      tryFlagForErase(theBlock, defOp, toBeErased);
      for (Value operand : allOp.getOperands()) pending.push_back(operand);
      continue;
    }

    // Broadcastability constraint: collect its transitive shape operands.
    if (auto bcastableOp = llvm::dyn_cast<shape::CstrBroadcastableOp>(defOp)) {
      if (!isWithinBlock(bcastableOp, theBlock)) return failure();
      tryFlagForErase(theBlock, defOp, toBeErased);
      CstrBroadcastableIntent intent(bcastableOp.getLoc());
      LogicalResult analyzed = analyzeBroadcastableConstraint(
          bcastableOp, theBlock, toBeErased, intent.operands);
      if (failed(analyzed)) return failure();
      broadcastabilityCstrs.push_back(intent);
      continue;
    }

    // Witness produced by an assuming region, which will be inlined later:
    // continue with the value it yields.
    if (auto assumingOp = llvm::dyn_cast<shape::AssumingOp>(defOp)) {
      if (!isWithinBlock(assumingOp, theBlock)) return failure();
      auto yieldOp = llvm::cast<shape::AssumingYieldOp>(
          assumingOp.getBody()->getTerminator());
      size_t resultIdx =
          llvm::find(assumingOp.getResults(), witness).getIndex();
      pending.push_back(yieldOp.getOperand(resultIdx));
      continue;
    }

    // Anything else is unsupported.
    return failure();
  }
  return success();
}
// Fuses all constraints guarding assuming regions in `theBlock` into a single
// assuming region at the top of the block. The steps are: analyze, simplify,
// materialize.
//
// Fails, without modifying the block, if the analysis fails. Succeeds without
// changes if there are no constraints to fuse.
LogicalResult fuseBlockGlobalConstraints(Location loc, OpBuilder &builder,
                                         Block *theBlock) {
  llvm::SmallSetVector<Operation *, 16> toBeErased;
  SmallVector<BlockArgument> argumentCstrs;
  SmallVector<CstrBroadcastableIntent> broadcastabilityCstrs;

  LogicalResult analysis = analyzeBlockGlobalConstraints(
      theBlock, toBeErased, argumentCstrs, broadcastabilityCstrs);
  if (failed(analysis)) return failure();

  bool hasConstraints =
      !argumentCstrs.empty() || !broadcastabilityCstrs.empty();
  if (!hasConstraints) return success();

  eliminateDuplicateBlockArguments(argumentCstrs);
  canonicalizeBroadcastabilityCstrs(broadcastabilityCstrs);

  materializeBlockGlobalConstraintFusion(loc, builder, theBlock, toBeErased,
                                         argumentCstrs, broadcastabilityCstrs);
  return success();
}
// Function pass that runs constraint fusion on every block of the function
// body. The pass fails at the first block for which fusion fails.
struct ConstraintFusionPass
    : public ConstraintFusionPassBase<ConstraintFusionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect>();
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    Location funcLoc = funcOp.getLoc();
    OpBuilder builder(&getContext());
    for (Block &block : funcOp.getBody()) {
      if (succeeded(fuseBlockGlobalConstraints(funcLoc, builder, &block)))
        continue;
      signalPassFailure();
      return;
    }
  }
};
} // namespace
// Creates a new instance of the constraint fusion pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreateConstraintFusionPass() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<ConstraintFusionPass>();
  return pass;
}
} // namespace mhlo
} // namespace mlir