//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a module-level pass performing interprocedural
// propagation of array shapes through function specialization.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#define DEBUG_TYPE "toy-shape-inference"
using namespace toy;
using llvm::MutableArrayRef;
using llvm::SmallVector;
using llvm::SmallVectorImpl;
using llvm::StringRef;
using llvm::Twine;
/// Create a mangled name for a function specialization. We simply append the
/// shapes of the arguments to the function name. For example, calling
///
///   "toy.generic_call"(%1, %3) {callee: "foo"}
///     : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
///
/// would be mangled as "foo_2x3_2x3". This mangling isn't robust, as the user
/// could have provided a function with a colliding name, but we claim this as
/// a feature: it allows the user to provide custom specializations!
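///
/// For instance, if the user has already defined a function named
/// "foo_2x3_2x3", the specialization lookup will find it and call it
/// directly, instead of cloning and inferring "foo".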
static std::string mangle(StringRef funcName,
                          MutableArrayRef<mlir::OpOperand> operands) {
  std::string mangledName;
  mangledName.reserve(funcName.size() + operands.size() * 6);
  mangledName = funcName;
  for (auto &operand : operands) {
    auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
    mangledName += "_";
    const char *sep = "";
    for (auto dim : arrayTy.getShape()) {
      mangledName += (sep + Twine(dim)).str();
      sep = "x";
    }
  }
  return mangledName;
}
namespace {
/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
/// whole. MLIR also supports FunctionPass, which is restricted to modifying a
/// single function at a time. This pass couldn't be a function pass due to
/// the interprocedural nature of its transformations.
///
/// The algorithm has two levels; first, intra-procedurally:
///
///   1) Build a worklist containing all the operations that return a
///      generic Toy array: these are the operations that need shape
///      inference.
///   2) Iterate on the worklist:
///     a) find an operation to process: the next ready operation in the
///        worklist has all of its arguments non-generic,
///     b) if no operation is found, break out of the loop,
///     c) remove the operation from the worklist,
///     d) infer the shape of its output from the argument types.
///   3) If the worklist is empty, the algorithm succeeded and we infer the
///      return type for the function from the return operation.
///
/// There is a twist though: when a call to a generic function is encountered,
/// shape inference requires the return type of the callee to be inferred
/// first. At this point we need to specialize the callee by cloning it. Here
/// is the inter-procedural flow (see the example below):
///
///   1) Keep a worklist of functions to process. Start with function "main".
///   2) While the worklist isn't empty:
///     a) Take the last inserted function in the worklist.
///     b) Run the intra-procedural shape inference on this function.
///     c) If the intra-procedural shape inference can't complete because it
///        encountered a call to a not-yet-specialized callee, queue that
///        callee and keep the current function in the worklist. Otherwise the
///        inference succeeded and we can pop the function from the worklist.
///
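/// For example (illustrative IR, reusing the types from the mangling example
/// above): when inference for a caller reaches
///
///   "toy.generic_call"(%1, %3) {callee: "foo"}
///     : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
///
/// and no "foo_2x3_2x3" exists yet, "foo" is cloned into "foo_2x3_2x3" with
/// !toy<"array<2, 3>"> arguments, queued, and inferred; the caller then
/// resumes and rewrites the generic call into a direct call to "foo_2x3_2x3".
///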
class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
public:
  // One entry in the inter-procedural worklist. It keeps track of the
  // function to process, the mangled name for this specialization, and the
  // types of the arguments on which to specialize.
  struct FunctionToSpecialize {
    mlir::FuncOp function;
    std::string mangledName;
    SmallVector<mlir::Type, 4> argumentsType;
  };
  void runOnModule() override {
    auto module = getModule();
    mlir::ModuleManager moduleManager(module);
    auto main = moduleManager.lookupSymbol<mlir::FuncOp>("main");
    if (!main) {
      emitError(mlir::UnknownLoc::get(module.getContext()),
                "shape inference failed: can't find a main function\n");
      signalPassFailure();
      return;
    }

    /// Inter-procedural loop: initialize the worklist with `main` and iterate
    /// until the full call-graph reachable from `main` has been successfully
    /// inferred.
    SmallVector<FunctionToSpecialize, 8> worklist;
    worklist.push_back({main, "", {}});
    while (!worklist.empty()) {
      if (failed(specialize(worklist, moduleManager)))
        return;
    }

    // Delete any generic functions left over.
    // FIXME: we may want this as a separate pass.
    for (mlir::FuncOp function :
         llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
      if (auto genericAttr =
              function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
        if (genericAttr.getValue())
          function.erase();
      }
    }
  }
  /// Run inference on a function. If a mangledName is provided, we need to
  /// specialize the function; to this end, clone it first.
  mlir::LogicalResult
  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
             mlir::ModuleManager &moduleManager) {
    FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
    mlir::FuncOp f = functionToSpecialize.function;

    // Check if cloning for specialization is needed (usually anything but
    // main). We will create a new function with the concrete types for the
    // parameters and clone the body into it.
    if (!functionToSpecialize.mangledName.empty()) {
      if (moduleManager.lookupSymbol<mlir::FuncOp>(
              functionToSpecialize.mangledName)) {
        funcWorklist.pop_back();
        // FuncOp already specialized, move on.
        return mlir::success();
      }
      // Create a new function with a generic array return type; it will be
      // updated when the inference for the function body completes.
      auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
                                          {ToyArrayType::get(&getContext())},
                                          &getContext());
      auto newFunction =
          mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName,
                               type, f.getDialectAttrs());
      moduleManager.insert(newFunction);

      // Clone the function body.
      mlir::BlockAndValueMapping mapper;
      f.cloneInto(newFunction, mapper);
      LLVM_DEBUG({
        llvm::dbgs() << "====== Cloned : \n";
        f.dump();
        llvm::dbgs() << "====== Into : \n";
        newFunction.dump();
      });
      f = newFunction;
      f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));

      // Remap the entry-block arguments.
      // FIXME: this seems like a bug in `cloneInto()` above?
      auto &entryBlock = f.getBlocks().front();
      int blockArgSize = entryBlock.getArguments().size();
      assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
      entryBlock.addArguments(f.getType().getInputs());
      auto argList = entryBlock.getArguments();
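      // Indexing note: after addArguments() the entry block holds the old
      // (generic) arguments at [0, blockArgSize) followed by the new typed
      // ones. Each iteration below rewires the uses of the current first old
      // argument to its typed counterpart and erases it; the erase shifts the
      // argument list left by one, so index `blockArgSize` always denotes the
      // next new argument.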
      for (int argNum = 0; argNum < blockArgSize; ++argNum) {
        argList[0]->replaceAllUsesWith(argList[blockArgSize]);
        entryBlock.eraseArgument(0);
      }
      assert(succeeded(verify(f)));
    }
    LLVM_DEBUG(llvm::dbgs()
               << "Run shape inference on : '" << f.getName() << "'\n");

    auto *toyDialect = getContext().getRegisteredDialect("toy");
    if (!toyDialect) {
      signalPassFailure();
      return emitError(mlir::UnknownLoc::get(&getContext()),
                       "Toy dialect is not registered");
    }

    // Populate the worklist with the operations that need shape inference:
    // these are the Toy operations that return a generic array.
    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
    f.walk([&](mlir::Operation *op) {
      if (op->getDialect() == toyDialect) {
        if (op->getNumResults() == 1 &&
            op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
          opWorklist.insert(op);
      }
    });
    // Iterate on the operations in the worklist until all operations have
    // been inferred or no change happened (fixed point).
    while (!opWorklist.empty()) {
      // Find the next operation ready for inference, that is, an operation
      // with all of its operands already resolved (non-generic).
      auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
        return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
          return !ty.cast<ToyArrayType>().isGeneric();
        });
      });
      if (nextop == opWorklist.end())
        break; // failure: no operations can be inferred.

      mlir::Operation *op = *nextop;
      opWorklist.erase(op);
      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
      // The add operation is trivial: propagate the input type as is.
      if (auto addOp = llvm::dyn_cast<AddOp>(op)) {
        op->getResult(0)->setType(op->getOperand(0)->getType());
        continue;
      }

      // Transpose is easy: just invert the dimensions.
      if (op->getName().getStringRef() == "toy.transpose") {
        SmallVector<int64_t, 2> dims;
        auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>();
        dims.insert(dims.end(), arrayTy.getShape().begin(),
                    arrayTy.getShape().end());
        if (dims.size() == 2)
          std::swap(dims[0], dims[1]);
        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
        continue;
      }
      // Multiplication is a bit trickier: handle rank 1 as a dot product and
      // rank 2 as a matrix multiplication. We need to be careful about rank
      // mismatch here: the verifier would catch it, but shape inference runs
      // earlier in the pass and may see invalid IR (coming from invalid Toy
      // input, of course), and we wouldn't want to crash here.
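      //
      // For example, a <2, 3> LHS times a <3, 4> RHS yields a <2, 4> result,
      // while two rank-1 operands yield the <1> result of a dot product.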
      if (auto mulOp = llvm::dyn_cast<MulOp>(op)) {
        auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
        auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
        auto lhsRank = lhs.getShape().size();
        auto rhsRank = rhs.getShape().size();
        if (lhsRank != rhsRank) {
          return op->emitError("shape mismatch: LHS and RHS must have the same "
                               "rank for multiplication, got ")
                 << lhsRank << " vs " << rhsRank;
        }
        SmallVector<int64_t, 2> dims;
        if (lhsRank == 1) {
          // Dot product: the result shape is <1>.
          dims.push_back(1);
        } else {
          if (lhsRank != 2) {
            return op->emitError("shape mismatch: expect rank 1 or 2 for mul "
                                 "operands, got ")
                   << lhsRank;
          }
          // Matrix multiplication: <m, k> x <k, n> -> <m, n>.
          dims.push_back(lhs.getShape()[0]);
          dims.push_back(rhs.getShape()[1]);
        }
        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
        continue;
      }
      // Process calls: look up the callee after mangling the name with the
      // argument shapes. If the callee does not exist, we stop the inference
      // for this function, queue the callee in the inter-procedural worklist,
      // and return. The current function stays in the worklist and will
      // restart after the callee is processed.
      if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
        auto calleeName = callOp.getCalleeName();
        auto callee = moduleManager.lookupSymbol<mlir::FuncOp>(calleeName);
        if (!callee) {
          signalPassFailure();
          return f.emitError("shape inference failed, call to unknown '")
                 << calleeName << "'";
        }
        auto mangledName = mangle(calleeName, op->getOpOperands());
        LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
                                << "', mangled: '" << mangledName << "'\n");
        auto mangledCallee =
            moduleManager.lookupSymbol<mlir::FuncOp>(mangledName);
        if (!mangledCallee) {
          // Can't find the target: this is where we queue the request for the
          // callee and stop the inference for the current function now.
          funcWorklist.push_back({callee, std::move(mangledName),
                                  llvm::to_vector<4>(op->getOperandTypes())});
          return mlir::success();
        }
        // Found a specialized callee! Let's turn this into a normal call
        // operation.
        SmallVector<mlir::Value *, 8> operands(op->getOperands());
        mlir::OpBuilder builder(f.getBody());
        builder.setInsertionPoint(op);
        auto newCall =
            builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
        // The generic call is always superseded by the new direct call:
        // forward the result (if any) and erase the original operation so no
        // dead generic call is left behind.
        if (newCall.getNumResults())
          op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
        op->erase();
        continue;
      }
    }
    // Done with inference on this function, remove it from the worklist.
    funcWorklist.pop_back();
    // Mark the function as non-generic now that inference has succeeded.
    f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));

    // If the operation worklist isn't empty, this indicates a failure.
    if (!opWorklist.empty()) {
      signalPassFailure();
      auto diag = f.emitError("shape inference failed, ")
                  << opWorklist.size() << " operations couldn't be inferred\n";
      for (auto *ope : opWorklist)
        diag << " - " << *ope << "\n";
      return diag;
    }
    // Finally, update the return type of the function based on the argument
    // to the return operation.
    for (auto &block : f.getBlocks()) {
      // Use dyn_cast here: cast<> would assert instead of returning null on a
      // non-ReturnOp terminator.
      auto ret = llvm::dyn_cast<ReturnOp>(block.getTerminator());
      if (!ret)
        continue;
      if (ret.getNumOperands() &&
          f.getType().getResult(0) == ret.getOperand()->getType())
        // Type already matches, we're done.
        break;
      SmallVector<mlir::Type, 1> retTy;
      if (ret.getNumOperands())
        retTy.push_back(ret.getOperand()->getType());
      std::vector<mlir::Type> argumentsType;
      for (auto arg : f.getArguments())
        argumentsType.push_back(arg->getType());
      auto newType =
          mlir::FunctionType::get(argumentsType, retTy, &getContext());
      f.setType(newType);
      assert(succeeded(verify(f)));
      break;
    }
    return mlir::success();
  }
};
} // end anonymous namespace
namespace toy {
std::unique_ptr<mlir::Pass> createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}
} // namespace toy